102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
|
|
||
|
import os

import torch

import models
|
||
|
|
||
|
|
||
|
class Server(object):
    """Hold a global model, apply accumulated weight updates and evaluate it.

    Attributes:
        conf: Configuration mapping. This class reads the keys
            ``"model_name"``, ``"batch_size"`` and ``"lambda"``.
        global_model: Model returned by ``models.get_model``.
        eval_loader: ``DataLoader`` over the evaluation dataset.
    """

    # Where model_eval() stores the global model's state dict.  torch.save
    # writes PyTorch's own serialization format regardless of the suffix.
    CHECKPOINT_PATH = "./data/model_parameter.h5"

    def __init__(self, conf, eval_dataset):
        """Build the global model and the evaluation data loader.

        Args:
            conf: Configuration mapping; must provide ``"model_name"`` and
                ``"batch_size"`` (``"lambda"`` is read later by
                ``model_aggregate``).
            eval_dataset: Dataset handed to ``torch.utils.data.DataLoader``.
        """
        self.conf = conf
        self.global_model = models.get_model(self.conf["model_name"])
        # shuffle=True is kept from the original implementation; model_eval()
        # only accumulates sums over all batches.
        self.eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self.conf["batch_size"],
            shuffle=True,
        )

    def model_aggregate(self, weight_accumulator):
        """Add the scaled accumulated updates to the global model in place.

        Every entry of the global model's ``state_dict`` is incremented by
        ``weight_accumulator[name] * conf["lambda"]``.

        Args:
            weight_accumulator: Mapping from state-dict entry name to a tensor
                with the accumulated update for that entry.

        Raises:
            KeyError: If an entry of the state dict is missing from
                ``weight_accumulator``.
        """
        scale = self.conf["lambda"]
        for name, data in self.global_model.state_dict().items():
            update_per_layer = weight_accumulator[name] * scale
            # Cast to the destination tensor's own dtype/device instead of a
            # hard-coded int64: integer buffers still receive an integer
            # update, while float tensors of a different precision or device
            # are no longer truncated to integers.
            if (update_per_layer.dtype != data.dtype
                    or update_per_layer.device != data.device):
                update_per_layer = update_per_layer.to(
                    device=data.device, dtype=data.dtype
                )
            data.add_(update_per_layer)

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if none)."""
        first_param = next(self.global_model.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    @staticmethod
    def _confusion_counts(target, pred, num_classes):
        """Return a CPU count matrix of shape (num_classes, num_classes).

        Entry ``[t, p]`` is the number of samples with true label ``t`` that
        were predicted as ``p``.
        """
        # Encode each (true, predicted) pair as one flat index so a single
        # bincount replaces the per-sample Python loop.
        flat_index = (target.reshape(-1).long() * num_classes
                      + pred.reshape(-1).long())
        counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
        return counts.view(num_classes, num_classes).cpu()

    def model_eval(self):
        """Evaluate the global model on ``eval_loader`` and save its weights.

        The confusion matrix is sized from the model's output width
        (``output.size(1)``), so the macro averages cover exactly the classes
        the model predicts.

        Side effects:
            Puts the model in eval mode and writes its ``state_dict`` to
            ``CHECKPOINT_PATH`` (the parent directory is created if missing).

        Returns:
            Tuple ``(acc, total_l, average_precision, average_recall,
            average_f1_score)`` where ``acc`` is the accuracy in percent,
            ``total_l`` the mean cross-entropy loss per sample, and the last
            three are 0-dim tensors holding the unweighted per-class means
            (as fractions, not percent).

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.global_model.eval()
        device = self._model_device()

        total_loss = 0.0
        correct = 0
        dataset_size = 0
        confusion_matrix = None  # allocated once the number of classes is known

        # Evaluation needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for data, target in self.eval_loader:
                dataset_size += data.size(0)

                # Keep the batch on the same device as the model.
                data = data.to(device)
                target = target.to(device)

                output = self.global_model(data)

                # Sum (not mean) so the final division by dataset_size is exact
                # even when the last batch is smaller.
                total_loss += torch.nn.functional.cross_entropy(
                    output, target, reduction="sum"
                ).item()
                pred = output.argmax(dim=1)  # index of the largest logit
                correct += pred.eq(target.view_as(pred)).sum().item()

                # Update the confusion matrix.
                num_classes = output.size(1)
                if confusion_matrix is None:
                    confusion_matrix = torch.zeros(num_classes, num_classes)
                confusion_matrix += self._confusion_counts(
                    target, pred, num_classes
                ).to(confusion_matrix.dtype)

        acc = 100.0 * (float(correct) / float(dataset_size))
        total_l = total_loss / dataset_size

        # Per-class precision, recall and F1; the 1e-9 terms avoid division by
        # zero for classes that never occur or are never predicted.
        true_positive = torch.diag(confusion_matrix)
        precision = true_positive / (confusion_matrix.sum(0) + 1e-9)
        recall = true_positive / (confusion_matrix.sum(1) + 1e-9)
        f1_score = 2 * precision * recall / (precision + recall + 1e-9)

        average_precision = torch.mean(precision)
        average_recall = torch.mean(recall)
        average_f1_score = torch.mean(f1_score)

        # Make sure the target directory exists so a finished evaluation is
        # not lost to a missing folder.
        checkpoint_dir = os.path.dirname(self.CHECKPOINT_PATH)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(self.global_model.state_dict(), self.CHECKPOINT_PATH)

        return acc, total_l, average_precision, average_recall, average_f1_score
|