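"""Utility helpers for the flower-classification training scripts: dataset
splitting (read_split_data), data-loader visualization, pickle I/O, and the
per-epoch training/evaluation loops (train_one_epoch, evaluate)."""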
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt

def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # traverse the root folder; each sub-folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to keep the order consistent across platforms
    flower_class.sort()
    # map each class name to a numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class indices of the validation images
    every_class_num = []     # number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # traverse the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect all file paths with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort to keep the order consistent across platforms
        images.sort()
        # numeric index of this class
        image_class = class_indices[cla]
        # record how many samples this class has
        every_class_num.append(len(images))
        # randomly sample validation images at the given rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # paths sampled above go into the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes into the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,... with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add a value label above each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # label the x-axis
        plt.xlabel('image class')
        # label the y-axis
        plt.ylabel('number of images')
        # title of the bar chart
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

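# Usage sketch (paths and class folders below are illustrative, not part of
# this module): read_split_data expects one sub-folder per class under the
# dataset root, e.g.
#   flower_photos/
#       daisy/xxx.jpg
#       roses/yyy.png
# train_paths, train_labels, val_paths, val_labels = read_split_data("./flower_photos")
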
def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the normalization (ImageNet mean/std)
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove the x-axis ticks
            plt.yticks([])  # remove the y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()

def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
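
# Minimal per-epoch wiring sketch (not part of this module): it assumes a
# `model`, an `optimizer`, and `train_loader`/`val_loader` DataLoaders built
# elsewhere from the path/label lists returned by read_split_data.
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# for epoch in range(epochs):
#     train_loss, train_acc = train_one_epoch(model, optimizer, train_loader,
#                                             device, epoch)
#     val_loss, val_acc = evaluate(model, val_loader, device, epoch)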