import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Split an image-folder dataset into training and validation lists.

    Every sub-directory of ``root`` is treated as one class. Class names are
    sorted and numbered from 0, and the ``index -> class name`` mapping is
    written to ``class_indices.json`` in the current working directory.

    Note: this function reseeds the global ``random`` module with 0 so the
    split is reproducible across runs.

    Args:
        root: Dataset root directory containing one folder per class.
        val_rate: Fraction of each class's images sampled for validation.
        plot_image: When True, show a bar chart of the per-class image counts.

    Returns:
        A 4-tuple ``(train_images_path, train_images_label,
        val_images_path, val_images_label)`` of parallel lists.

    Raises:
        AssertionError: If ``root`` does not exist, or if either the training
            or the validation split ends up empty.
    """
    random.seed(0)  # fixed seed so the random split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One folder per class; sorted so the order is the same on every platform.
    flower_class = sorted(
        cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))
    )

    # Map class name -> numeric index, and persist the inverse mapping.
    class_indices = {name: index for index, name in enumerate(flower_class)}
    json_str = json.dumps(
        {index: name for name, index in class_indices.items()}, indent=4
    )
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class index of each validation image
    every_class_num = []     # number of images found per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # Collect every supported image of this class, sorted for a
        # platform-independent order (which also fixes the sampling result).
        images = sorted(
            os.path.join(root, cla, name)
            for name in os.listdir(cla_path)
            if os.path.splitext(name)[-1] in supported
        )
        image_class = class_indices[cla]
        every_class_num.append(len(images))

        # Sample the validation subset; a set gives O(1) membership tests
        # while the random.sample call itself is unchanged.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    if plot_image:
        # Bar chart of the number of images in each class.
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # Replace the numeric x ticks with the class names.
        plt.xticks(range(len(flower_class)), flower_class)
        # Put the count above each bar.
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    """Display up to four images from every batch of ``data_loader``.

    Class names are looked up in ``./class_indices.json``, which must already
    exist (``read_split_data`` writes it).

    Args:
        data_loader: Iterable of ``(images, labels)`` batches exposing a
            ``batch_size`` attribute. Images are indexed as ``[C, H, W]``
            tensors and converted with ``.numpy()``.

    Raises:
        AssertionError: If ``./class_indices.json`` does not exist.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    # Context manager so the file handle is closed after loading.
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        # A final, smaller batch may hold fewer than plot_num images.
        for i in range(min(plot_num, len(images))):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize with these std/mean values (presumably the
            # statistics used by the input transform -- confirm there).
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # hide x-axis ticks
            plt.yticks([])  # hide y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    """Serialize ``list_info`` to ``file_name`` with pickle."""
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in ``file_name``.

    NOTE: unpickling can execute arbitrary code; only read files from a
    trusted source.
    """
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Uses cross-entropy loss with ``label_smoothing=0.1`` and steps the
    optimizer after every batch. If the loss becomes non-finite, a warning is
    printed and the process exits with status 1.

    Args:
        model: Network to train; put into training mode.
        optimizer: Optimizer stepped and zeroed after each batch.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches and accumulators are moved to.
        epoch: Epoch number, used only in the progress-bar text.

    Returns:
        ``(mean loss per batch, accuracy over all samples)`` as floats.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    num_batches = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # moved once, reused for accuracy and loss
        sample_num += images.shape[0]
        num_batches = step + 1

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch, accu_loss.item() / num_batches, accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    if num_batches == 0:
        # Without this guard the averages below would fail on an empty loader.
        raise ValueError("data_loader yielded no batches.")

    return accu_loss.item() / num_batches, accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Uses plain cross-entropy loss (no label smoothing).

    Args:
        model: Network to evaluate; put into eval mode.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches and accumulators are moved to.
        epoch: Epoch number, used only in the progress-bar text.

    Returns:
        ``(mean loss per batch, accuracy over all samples)`` as floats.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    num_batches = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # moved once, reused for accuracy and loss
        sample_num += images.shape[0]
        num_batches = step + 1

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch, accu_loss.item() / num_batches, accu_num.item() / sample_num)

    if num_batches == 0:
        # Without this guard the averages below would fail on an empty loader.
        raise ValueError("data_loader yielded no batches.")

    return accu_loss.item() / num_batches, accu_num.item() / sample_num