164 lines
5.2 KiB
Python
164 lines
5.2 KiB
Python
|
import argparse, json
|
|||
|
import datetime
|
|||
|
import os
|
|||
|
import logging
|
|||
|
import os
|
|||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
|||
|
import torch, random
|
|||
|
|
|||
|
|
|||
|
from server import *
|
|||
|
from client import *
|
|||
|
import models, datasets
|
|||
|
from torchvision.datasets import ImageFolder
|
|||
|
|
|||
|
import torch
|
|||
|
from torchvision import transforms, datasets
|
|||
|
import torch.nn as nn
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
import torchvision
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
import numpy as np
|
|||
|
from log import get_log
|
|||
|
|
|||
|
from torch import randperm
|
|||
|
|
|||
|
|
|||
|
import os
|
|||
|
|
|||
|
|
|||
|
|
|||
|
# Logger obtained from the local `log` module (`get_log`), pointed at a fixed
# file path on the training host.
# NOTE(review): nothing in this script currently writes to `logger`; the old
# metric-logging calls (MSE/RMSE/MAE/MAPE) were disabled and referenced names
# this script never defines.
logger = get_log(
    '/home/cds/brain-tumor_image_classification/log/logkashi.txt'
)
# Preprocessing applied to every image of both ImageFolder datasets.
# NOTE: this rebinds the name `transforms` from the torchvision module to the
# composed pipeline. The right-hand side is evaluated first, so the module's
# Resize/CenterCrop/ToTensor are still reachable here; afterwards the module
# itself is no longer accessible under this name.
transforms = transforms.Compose(
    [
        # Scale the shorter side to 256 px, preserving the aspect ratio.
        transforms.Resize(256),
        # Keep the central 224 x 224 region.
        transforms.CenterCrop(224),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
    ]
)
def count_class_samples(loader, num_classes):
    """Return how many samples of each class index ``loader`` yields.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches, e.g. a client's
            ``train_loader``. ``targets`` must be a 1-D integer tensor or
            array of class indices in ``[0, num_classes)``; ``inputs`` is
            ignored.
        num_classes: Length of the returned histogram.

    Returns:
        ``np.ndarray`` of shape ``(num_classes,)`` and dtype int64 whose
        entry ``i`` is the number of samples labelled ``i``. An empty loader
        yields all zeros.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for _, targets in loader:
        labels = np.asarray(targets, dtype=np.int64).ravel()
        # minlength keeps the histogram at num_classes even when a batch
        # lacks the highest class indices.
        counts += np.bincount(labels, minlength=num_classes)
    return counts


if __name__ == '__main__':
    # Federated-learning driver: read the JSON config, load the train/test
    # ImageFolder datasets, build one Server and conf["no_models"] Clients,
    # then run conf["global_epochs"] rounds of sample / local-train /
    # aggregate / evaluate.
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('--conf', default='/home/cds/brain-tumor_image_classification/utils/conf.json', dest='conf')
    # Dataset roots are overridable; defaults are the original fixed paths.
    parser.add_argument('--train_dir', dest='train_dir',
                        default='/home/cds/brain-tumor_image_classification/data/Brain Tumor MRI Dataset/archive/Training')
    parser.add_argument('--test_dir', dest='test_dir',
                        default='/home/cds/brain-tumor_image_classification/data/Brain Tumor MRI Dataset/archive/Testing')
    args = parser.parse_args()

    with open(args.conf, 'r', encoding='utf-8') as f:
        conf = json.load(f)

    # `ImageFolder` is torchvision's class (imported directly); using it avoids
    # depending on which of the two imported `datasets` modules won the name.
    data_train = ImageFolder(args.train_dir, transform=transforms)
    data_test = ImageFolder(args.test_dir, transform=transforms)
    print(data_train)

    # Wrap each dataset in a randomly permuted view (a fresh permutation every
    # run; no seed is set). Every Client receives the permuted training view,
    # so presumably the permutation decides which samples each client gets —
    # TODO confirm against Client's partitioning in client.py.
    train_order = randperm(len(data_train)).tolist()
    data_train_shuffle = torch.utils.data.Subset(data_train, train_order)
    test_order = randperm(len(data_test)).tolist()
    data_test_shuffle = torch.utils.data.Subset(data_test, test_order)

    train_datasets, eval_datasets = data_train_shuffle, data_test_shuffle

    server = Server(conf, eval_datasets)

    num_classes = len(data_train.classes)
    clients = []
    for c in range(conf["no_models"]):
        client = Client(conf, server.global_model, train_datasets, c)
        clients.append(client)
        # Print the label histogram of the client that will actually train,
        # reusing its own loader instead of constructing a duplicate Client.
        print(count_class_samples(client.train_loader, num_classes))

    print("\n\n")

    for e in range(conf["global_epochs"]):
        # conf["k"] distinct clients participate in this round.
        candidates = random.sample(clients, conf["k"])

        # Zero-initialised running sum of client updates, one tensor per
        # entry of the global model's state_dict.
        weight_accumulator = {
            name: torch.zeros_like(params)
            for name, params in server.global_model.state_dict().items()
        }

        for c in candidates:
            # `diff` must provide an entry for every global state_dict key.
            diff = c.local_train(server.global_model)
            for name, total in weight_accumulator.items():
                total.add_(diff[name])

        server.model_aggregate(weight_accumulator)

        acc, loss, average_precision, average_recall, average_f1_score = server.model_eval()
        print("Epoch {}, Accuracy: {:.2f}%, Loss: {:.2f}, Average Precision: {:.2f}%, Average Recall: {:.2f}%, Average F1-Score: {:.2f}%".format(
            e, acc, loss, average_precision * 100, average_recall * 100, average_f1_score * 100))