# brain-tumor_image_classification/main.py
#
# 164 lines · 5.2 KiB · Python  (GitHub file view: Raw | Blame | History)
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with
# others in your current locale. If your use case is intentional and
# legitimate, you can safely ignore this warning. Use the Escape button to
# highlight these characters.

import argparse, json
import datetime
import os
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch, random
from server import *
from client import *
import models, datasets
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from log import get_log
from torch import randperm
import os
# File logger for this run (path is specific to the author's machine).
logger = get_log('/home/cds/brain-tumor_image_classification/log/logkashi.txt')

# Preprocessing shared by the training and test images.
# NOTE(review): the assignment below rebinds ``transforms`` from the
# torchvision module to the composed pipeline; after this line the module's
# classes (Resize, CenterCrop, ...) are no longer reachable under that name.
_preprocess_steps = [
    transforms.Resize(256),      # scale the shorter side to 256 px, keeping the aspect ratio
    transforms.CenterCrop(224),  # cut the central 224x224 patch
    transforms.ToTensor(),       # PIL image -> float tensor with values in [0, 1]
]
transforms = transforms.Compose(_preprocess_steps)
if __name__ == '__main__':
    # Federated training entry point: each global epoch samples conf["k"]
    # clients, sums their local_train() updates and hands the sum to
    # server.model_aggregate(), then reports the server's evaluation metrics.
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('--conf', default='/home/cds/brain-tumor_image_classification/utils/conf.json', dest='conf')
    # Dataset roots used to be hard-coded; the old values remain the defaults.
    parser.add_argument(
        '--train-dir',
        dest='train_dir',
        default='/home/cds/brain-tumor_image_classification/data/Brain Tumor MRI Dataset/archive/Training',
        help='ImageFolder root of the training split',
    )
    parser.add_argument(
        '--test-dir',
        dest='test_dir',
        default='/home/cds/brain-tumor_image_classification/data/Brain Tumor MRI Dataset/archive/Testing',
        help='ImageFolder root of the test split',
    )
    args = parser.parse_args()

    # Keys read in this block: "no_models", "k", "global_epochs"
    # (Server/Client presumably read further keys).
    with open(args.conf, 'r') as f:
        conf = json.load(f)

    # ``datasets`` is torchvision.datasets here (the later
    # ``from torchvision import transforms, datasets`` import shadows the
    # project's ``datasets`` module); ``transforms`` is the module-level
    # Compose pipeline.
    data_train = datasets.ImageFolder(args.train_dir, transform=transforms)
    data_test = datasets.ImageFolder(args.test_dir, transform=transforms)
    print(data_train)

    # ImageFolder lists samples grouped by class folder. View each split through
    # a random permutation of its indices (presumably so clients that take
    # contiguous index ranges get a class mix -- verify against Client).
    train_perm = randperm(len(data_train)).tolist()
    train_datasets = torch.utils.data.Subset(data_train, train_perm)
    eval_perm = randperm(len(data_test)).tolist()
    eval_datasets = torch.utils.data.Subset(data_test, eval_perm)

    server = Server(conf, eval_datasets)

    num_classes = len(data_train.classes)
    clients = []
    for client_id in range(conf["no_models"]):
        # Construct each client once and take its label histogram from that
        # same instance. The previous code built a second Client with identical
        # arguments only for counting, which doubled construction cost and, if
        # Client.__init__ is randomized, could count a different split than
        # the one actually trained.
        client = Client(conf, server.global_model, train_datasets, client_id)
        clients.append(client)

        # Per-class sample counts seen through this client's train_loader,
        # sized from the dataset rather than a hard-coded 4 classes.
        class_num_total = np.zeros(num_classes, dtype=np.int64)
        for _, target in client.train_loader:
            class_num_total += np.bincount(np.asarray(target), minlength=num_classes)
        print(class_num_total)
    print("\n\n")

    for e in range(conf["global_epochs"]):
        candidates = random.sample(clients, conf["k"])

        # One zero tensor per entry of the global model's state_dict; the
        # selected clients' updates are summed into it.
        weight_accumulator = {
            name: torch.zeros_like(params)
            for name, params in server.global_model.state_dict().items()
        }
        for client in candidates:
            diff = client.local_train(server.global_model)
            for name, total in weight_accumulator.items():
                total.add_(diff[name])

        server.model_aggregate(weight_accumulator)

        # NOTE(review): precision/recall/F1 are scaled by 100 below but acc is
        # not, so model_eval presumably already returns acc as a percentage --
        # confirm in Server.
        acc, loss, average_precision, average_recall, average_f1_score = server.model_eval()
        print("Epoch {}, Accuracy: {:.2f}%, Loss: {:.2f}, Average Precision: {:.2f}%, Average Recall: {:.2f}%, Average F1-Score: {:.2f}%".format(
            e, acc, loss, average_precision * 100, average_recall * 100, average_f1_score * 100))