brain-tumor_image_classific.../server.py

102 lines
3.2 KiB
Python
Raw Permalink Normal View History

2024-06-21 10:34:43 +08:00
import os

import torch

import models
class Server(object):
    """Holds a global model, merges accumulated updates into it, and evaluates it.

    Presumably the server side of a federated-learning loop (inferred from the
    ``global_model`` / ``weight_accumulator`` naming) -- confirm against callers.

    The configuration mapping ``conf`` is read for the keys ``"model_name"``,
    ``"batch_size"`` and ``"lambda"``.
    """

    # Side length of the confusion matrix built by model_eval(); every target
    # label and predicted class must be a valid index into a matrix this size.
    NUM_CLASSES = 4

    def __init__(self, conf, eval_dataset):
        """Build the global model and the evaluation data loader.

        Args:
            conf: mapping providing ``"model_name"`` (passed to
                ``models.get_model``) and ``"batch_size"``.
            eval_dataset: dataset whose batches unpack as ``(data, target)``.
        """
        self.conf = conf
        self.global_model = models.get_model(self.conf["model_name"])
        # NOTE(review): shuffling is not needed for the summed metrics computed
        # in model_eval(); kept as-is to avoid changing RNG consumption.
        self.eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self.conf["batch_size"],
            shuffle=True,
        )

    def model_aggregate(self, weight_accumulator):
        """Add ``weight_accumulator[name] * conf["lambda"]`` to every model entry.

        The update is applied in place to each tensor of the global model's
        ``state_dict()``.

        Args:
            weight_accumulator: mapping from state-dict entry name to a tensor
                that can be added to that entry. It must contain every name
                present in the global model's ``state_dict()``.

        Raises:
            KeyError: if an entry name is missing from ``weight_accumulator``.
        """
        scale = self.conf["lambda"]
        # state_dict() tensors share storage with the model's parameters and
        # buffers, so the in-place add below modifies the model itself.
        for name, data in self.global_model.state_dict().items():
            update_per_layer = weight_accumulator[name] * scale
            # Cast to the destination's own dtype/device rather than a fixed
            # int64: integer entries (e.g. BatchNorm's num_batches_tracked)
            # still receive a truncated integer update, while floating-point
            # entries of a different precision keep their fractional part.
            data.add_(update_per_layer.to(device=data.device, dtype=data.dtype))

    def model_eval(self):
        """Evaluate the global model on the evaluation loader and save its weights.

        The model's ``state_dict()`` is written to
        ``./data/model_parameter.h5`` (the directory is created if missing).

        Returns:
            A tuple ``(acc, total_l, average_precision, average_recall,
            average_f1_score)`` where ``acc`` is the accuracy in percent
            (float), ``total_l`` is the summed cross-entropy divided by the
            number of samples (float), and the last three are 0-dim tensors
            holding the unweighted mean over ``NUM_CLASSES`` classes of the
            per-class precision, recall and F1 score (as fractions, not
            percentages).

        Raises:
            ZeroDivisionError: if the evaluation loader yields no samples.
            IndexError: if a target or prediction is not a valid class index.
        """
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        # Rows are indexed by the true label, columns by the predicted label.
        confusion_matrix = torch.zeros(self.NUM_CLASSES, self.NUM_CLASSES)

        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for data, target in self.eval_loader:
                dataset_size += data.size(0)
                # Assumes models.get_model() already placed the model on the
                # GPU whenever one is available -- TODO confirm.
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()

                output = self.global_model(data)
                # Sum (not mean) per batch so the final division by
                # dataset_size yields a true per-sample average.
                total_loss += torch.nn.functional.cross_entropy(
                    output, target, reduction="sum"
                ).item()

                pred = output.argmax(dim=1)  # index of the largest logit
                correct += pred.eq(target.view_as(pred)).sum().item()

                # Update the confusion matrix for the whole batch at once;
                # accumulate=True makes repeated (label, pred) pairs add up.
                rows = target.view(-1).long().cpu()
                cols = pred.view(-1).long().cpu()
                confusion_matrix.index_put_(
                    (rows, cols),
                    confusion_matrix.new_ones(rows.numel()),
                    accumulate=True,
                )

        acc = 100.0 * (float(correct) / float(dataset_size))
        total_l = total_loss / dataset_size

        # Per-class precision, recall and F1; the 1e-9 terms avoid 0/0 for
        # classes that were never predicted or never present.
        true_positives = torch.diag(confusion_matrix)
        precision = true_positives / (confusion_matrix.sum(0) + 1e-9)
        average_precision = torch.mean(precision)
        recall = true_positives / (confusion_matrix.sum(1) + 1e-9)
        average_recall = torch.mean(recall)
        f1_score = 2 * precision * recall / (precision + recall + 1e-9)
        average_f1_score = torch.mean(f1_score)

        # Make sure the target directory exists so a completed evaluation
        # pass is not lost to a failing save.
        os.makedirs("./data", exist_ok=True)
        torch.save(self.global_model.state_dict(), "./data/model_parameter.h5")
        return acc, total_l, average_precision, average_recall, average_f1_score