# Federated learning training script for brain-tumor MRI image classification.
import argparse
import json
import datetime
import logging
import os

# Pin all CUDA work to GPU index 1; must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import random

import torch

from server import *
from client import *
import models, datasets
from torchvision.datasets import ImageFolder
# NOTE: this rebinds `datasets` (imported above as a local module) to
# torchvision.datasets — later ImageFolder calls rely on that rebinding.
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from log import get_log
from torch import randperm
# Project-wide logger; all runs append to the chapter's shared log file.
logger = get_log('/home/ykn/cds/chapter03_Python_image_classification/log/log.txt')

# Regression-metric logging left over from an earlier experiment (disabled):
#logger.info("MSE: %.6f" % (mse))
#logger.info("RMSE: %.6f" % (rmse))
#logger.info("MAE: %.6f" % (mae))
#logger.info("MAPE: %.6f" % (mape))
|||
|
|
|||
|
|
|||
|
transforms = transforms.Compose([
|
|||
|
transforms.Resize(256), # 将图片短边缩放至256,长宽比保持不变:
|
|||
|
transforms.CenterCrop(224), #将图片从中心切剪成3*224*224大小的图片
|
|||
|
transforms.ToTensor() #把图片进行归一化,并把数据转换成Tensor类型
|
|||
|
])
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Command-line interface: only the path to the JSON configuration file.
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('--conf', default = '/home/ykn/cds/chapter03_Python_image_classification/utils/conf.json', dest='conf')
    args = parser.parse_args()

    # conf holds "no_models", "k", "global_epochs", etc. — consumed below and
    # by Server/Client.
    with open(args.conf, 'r') as f:
        conf = json.load(f)

    # Brain-tumor MRI data laid out one-folder-per-class for ImageFolder.
    path1 = '/home/ykn/cds/chapter03_Python_image_classification/data/Brain Tumor MRI Dataset/archive/Training'
    path2 = '/home/ykn/cds/chapter03_Python_image_classification/data/Brain Tumor MRI Dataset/archive/Testing'
    # `transforms` here is the module-level Compose pipeline, not the
    # torchvision module (it is shadowed at module scope).
    data_train = datasets.ImageFolder(path1, transform=transforms)
    data_test = datasets.ImageFolder(path2, transform=transforms)
    print(data_train)

    # Shuffle each split once by wrapping it in a randomly permuted Subset.
    train_datasets = torch.utils.data.Subset(data_train, randperm(len(data_train)).tolist())
    eval_datasets = torch.utils.data.Subset(data_test, randperm(len(data_test)).tolist())

    server = Server(conf, eval_datasets)
    clients = [Client(conf, server.global_model, train_datasets, c)
               for c in range(conf["no_models"])]

    print("\n\n")
    for e in range(conf["global_epochs"]):
        # Sample k participants for this round.
        random.shuffle(clients)
        selected = clients[:conf['k']]

        # FIX: zero the accumulator once per round (not per client) so the
        # updates from all k participants are summed before aggregation.
        weight_accumulator = {}
        for name, params in server.global_model.state_dict().items():
            weight_accumulator[name] = torch.zeros_like(params)

        # FIX: every selected client trains and contributes its diff; in the
        # flattened original only the last-bound `client` would have trained.
        for client in selected:
            print(client.client_id)
            diff = client.local_train(server.global_model)
            for name in weight_accumulator:
                weight_accumulator[name].add_(diff[name])

        # Fold the accumulated updates into the global model, then evaluate.
        server.model_aggregate(weight_accumulator)
        acc, loss = server.model_eval()
        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))