242 lines
9.5 KiB
Python
242 lines
9.5 KiB
Python
|
import datetime
|
|||
|
import os
|
|||
|
|
|||
|
import torch
|
|||
|
import matplotlib
|
|||
|
|
|||
|
import scipy.signal
|
|||
|
from matplotlib import pyplot as plt
|
|||
|
from torch.utils.tensorboard import SummaryWriter
|
|||
|
|
|||
|
import shutil
|
|||
|
import numpy as np
|
|||
|
|
|||
|
from PIL import Image
|
|||
|
from tqdm import tqdm
|
|||
|
from .utils import cvtColor, preprocess_input, resize_image
|
|||
|
from .utils_bbox import DecodeBox
|
|||
|
from .utils_map import get_coco_map, get_map
|
|||
|
|
|||
|
# Use the non-interactive Agg backend so the loss / mAP figures produced below
# can be saved to image files without a display.
# NOTE(review): pyplot is already imported above this call; recent matplotlib
# versions switch the backend here, but very old versions required use() to run
# before importing pyplot — confirm the minimum supported matplotlib version.
matplotlib.use('Agg')
|
|||
|
|
|||
|
|
|||
|
class LossHistory():
    """Track per-epoch training/validation losses and export them.

    Every call to ``append_loss`` appends both losses to text files in
    ``log_dir``, logs them as TensorBoard scalars and redraws
    ``epoch_loss.png`` in the same directory.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create ``log_dir`` (if needed) and a TensorBoard writer for it.

        Args:
            log_dir: Output directory for the loss text files, the loss plot
                and the TensorBoard event files.
            model: Network whose graph is traced into TensorBoard on a
                best-effort basis.
            input_shape: Spatial input size; element 0 is used as the height
                and element 1 as the width of the dummy tracing batch.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: re-using an existing directory must not abort training
        # with FileExistsError.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            # Random batch of two 3-channel images, used only for tracing.
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is optional. Tracing can fail (for example when the
            # model and the CPU dummy input are on different devices), and that
            # must not stop training.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses, persist them and refresh the plot.

        Args:
            epoch: Epoch index, used as the TensorBoard step.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch.
        """
        # Recreate the directory in case it was removed after __init__.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        for filename, value in (("epoch_loss.txt", loss), ("epoch_val_loss.txt", val_loss)):
            with open(os.path.join(self.log_dir, filename), 'a') as f:
                f.write(str(value))
                f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Redraw ``epoch_loss.png`` from the recorded losses.

        The raw train/val curves are always drawn. Savitzky-Golay smoothed
        curves (polynomial order 3; window 5 below 25 epochs, 15 otherwise)
        are added once there are at least as many points as the window length.
        """
        iters = range(len(self.losses))
        # savgol_filter (mode='interp') needs window_length <= number of points.
        window = 5 if len(self.losses) < 25 else 15

        fig = plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

            if len(self.losses) >= window:
                try:
                    smooth_train = scipy.signal.savgol_filter(self.losses, window, 3)
                    smooth_val = scipy.signal.savgol_filter(self.val_loss, window, 3)
                except ValueError:
                    # Smoothing is cosmetic; skip it if scipy rejects the data.
                    pass
                else:
                    plt.plot(iters, smooth_train, 'green', linestyle='--', linewidth=2,
                             label='smooth train loss')
                    plt.plot(iters, smooth_val, '#8B4513', linestyle='--', linewidth=2,
                             label='smooth val loss')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Always release the figure, even if saving failed.
            plt.close(fig)
|
|||
|
|
|||
|
|
|||
|
class EvalCallback():
    """Periodically compute validation mAP and plot it over epochs.

    On evaluation epochs the model is run on every line of ``val_lines``.
    Per-image detection and ground-truth txt files are written under
    ``map_out_path`` and scored with ``get_coco_map``, falling back to
    ``get_map``. The score is appended to ``log_dir/epoch_map.txt`` and
    plotted to ``log_dir/epoch_map.png``. The temporary ``map_out_path`` is
    deleted after each evaluation.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True,
                 MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation configuration.

        Args:
            net: Detection network; replaced by ``model_eval`` on evaluation.
            input_shape: Network input size; element 0 is passed as the first
                and element 1 as the second dimension to ``DecodeBox``.
            anchors, anchors_mask: Passed through to ``DecodeBox``.
            class_names: Class names indexed by integer class id.
            num_classes: Number of classes.
            val_lines: Annotation lines of the form
                ``"<image path> left,top,right,bottom,class_id ..."``.
            log_dir: Directory receiving ``epoch_map.txt`` and ``epoch_map.png``.
            cuda: If true, inputs are moved to the GPU before inference.
            map_out_path: Temporary directory for the per-image txt files.
            max_boxes: Maximum number of detections written per image.
            confidence: Confidence threshold passed to non-maximum suppression.
            nms_iou: IoU threshold passed to non-maximum suppression.
            letterbox_image: Resize with padding instead of plain resizing.
            MINOVERLAP: IoU threshold used by the ``get_map`` fallback; also
                shown in the plot's y label.
            eval_flag: If false, evaluation is disabled entirely.
            period: Evaluate every ``period`` epochs.
        """
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]),
                                   self.anchors_mask)

        # Seed the history with (epoch 0, mAP 0); the same 0 is logged to file.
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            os.makedirs(self.log_dir, exist_ok=True)
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run the detector on one image and write its detection-results file.

        Writes ``<map_out_path>/detection-results/<image_id>.txt`` with one
        line ``"<class> <score> <left> <top> <right> <bottom>"`` per kept
        detection. The score is truncated to 6 characters and the coordinates
        to int. Only the ``self.max_boxes`` highest-scoring detections are
        kept, and classes not in ``class_names`` are skipped. An image without
        detections still gets an empty file.

        Args:
            image_id: File stem of the output txt file.
            image: Input image (the in-file caller passes a PIL image).
            class_names: Classes allowed in the output.
            map_out_path: Root of the temporary mAP directory; its
                ``detection-results`` sub-directory must already exist.
        """
        # First two dimensions of the original image, used by NMS to map boxes
        # back to the original image size.
        image_shape = np.array(np.shape(image)[0:2])
        # Convert to RGB so grayscale images do not break prediction; only RGB
        # input is supported, so every other mode is converted.
        image = cvtColor(image)
        # Resize to the network input size: letterboxed with gray bars for an
        # undistorted resize when letterbox_image is set, plain resize otherwise.
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # HWC -> CHW, then add the batch dimension.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            # Forward pass and box decoding; the decoded outputs are then
            # concatenated and filtered with non-maximum suppression.
            outputs = self.bbox_util.decode_box(self.net(images))
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                                                         image_shape, self.letterbox_image, conf_thres=self.confidence,
                                                         nms_thres=self.nms_iou)

        det_path = os.path.join(map_out_path, "detection-results", image_id + ".txt")
        # `with` closes the file on every path, including the early return.
        with open(det_path, "w", encoding='utf-8') as f:
            if results[0] is None:
                return

            detections = results[0]
            top_label = np.array(detections[:, 6], dtype='int32')
            # Score is the product of columns 4 and 5 (presumably objectness
            # and class confidence — confirm against non_max_suppression).
            top_conf = detections[:, 4] * detections[:, 5]
            top_boxes = detections[:, :4]

            # Keep the max_boxes highest-SCORING detections (sort by score,
            # not by class label).
            keep = np.argsort(top_conf)[::-1][:self.max_boxes]
            for label, box, conf in zip(top_label[keep], top_boxes[keep], top_conf[keep]):
                predicted_class = self.class_names[int(label)]
                if predicted_class not in class_names:
                    continue
                top, left, bottom, right = box
                f.write("%s %s %s %s %s %s\n" % (
                    predicted_class, str(conf)[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

    def _write_image_files(self, annotation_line):
        """Write the detection and ground-truth txt files for one annotation line.

        ``annotation_line`` is whitespace separated: an image path followed by
        boxes ``left,top,right,bottom,class_id`` (five integers each).
        """
        line = annotation_line.split()
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), so images
        # differing only after the first dot do not overwrite each other.
        image_id = os.path.splitext(os.path.basename(line[0]))[0]
        gt_boxes = [tuple(map(int, box.split(','))) for box in line[1:]]

        # Close the image file once inference is done instead of leaking one
        # handle per validation image.
        with Image.open(line[0]) as image:
            self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

        gt_path = os.path.join(self.map_out_path, "ground-truth", image_id + ".txt")
        with open(gt_path, "w") as new_f:
            for left, top, right, bottom, obj in gt_boxes:
                new_f.write("%s %s %s %s %s\n" % (self.class_names[obj], left, top, right, bottom))

    def _plot_map(self):
        """Redraw ``epoch_map.png`` from the recorded (epoch, mAP) history."""
        fig = plt.figure()
        try:
            plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Map %s' % str(self.MINOVERLAP))
            plt.title('A Map Curve')
            plt.legend(loc="upper right")
            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        finally:
            plt.close(fig)

    def on_epoch_end(self, epoch, model_eval):
        """Compute and record validation mAP at the end of ``epoch``.

        Nothing happens unless ``eval_flag`` is set and ``epoch`` is a multiple
        of ``period``. Otherwise the txt files are generated and scored, the
        score is appended to ``epoch_map.txt`` and ``epoch_map.png`` is redrawn.
        ``map_out_path`` is removed afterwards, even if evaluation fails.

        Args:
            epoch: Current epoch number.
            model_eval: Model to evaluate; it replaces ``self.net``.
        """
        # Check eval_flag first so a disabled callback never computes
        # `epoch % self.period` (which fails for period == 0).
        if not self.eval_flag or epoch % self.period != 0:
            return

        self.net = model_eval
        # makedirs also creates map_out_path itself.
        for sub_dir in ("ground-truth", "detection-results"):
            os.makedirs(os.path.join(self.map_out_path, sub_dir), exist_ok=True)

        try:
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                self._write_image_files(annotation_line)

            print("Calculate Map.")
            try:
                temp_map = get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
            except Exception:
                # If the COCO-style evaluation fails for any reason, fall back
                # to get_map at IoU threshold MINOVERLAP.
                temp_map = get_map(self.MINOVERLAP, False, path=self.map_out_path)
        finally:
            # Remove the temporary txt files even if evaluation failed.
            shutil.rmtree(self.map_out_path, ignore_errors=True)

        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        self._plot_map()
        print("Get map done.")
|