117 lines
5.1 KiB
Python
117 lines
5.1 KiB
Python
import json
|
||
import os
|
||
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image
|
||
from pycocotools.coco import COCO
|
||
from pycocotools.cocoeval import COCOeval
|
||
from tqdm import tqdm
|
||
|
||
from utils.utils import cvtColor, preprocess_input, resize_image
|
||
from yolo import YOLO
|
||
|
||
# ---------------------------------------------------------------------------#
#   map_mode selects what this script does when it is run:
#     0 - the full mAP pipeline: produce the predictions, then compute mAP
#     1 - only produce the predictions
#     2 - only compute mAP
# ---------------------------------------------------------------------------#
map_mode = 0
# -------------------------------------------------------#
#   Validation-set annotation file and image directory
# -------------------------------------------------------#
cocoGt_path = "coco_dataset/annotations/instances_val2017.json"
dataset_img_path = "coco_dataset/val2017"
# -------------------------------------------------------#
#   Folder the results are written to (map_out by default)
# -------------------------------------------------------#
temp_save_path = "map_out/coco_eval"
|
||
|
||
|
||
class mAP_YOLO(YOLO):
    """YOLO variant whose ``detect_image`` collects COCO-style result dicts."""

    # ---------------------------------------------------#
    #   Detect one image
    # ---------------------------------------------------#
    def detect_image(self, image_id, image, results):
        """Run detection on one image and append its boxes to ``results``.

        Args:
            image_id: Identifier of the image; stored as ``int(image_id)``.
            image: The image to run detection on (the caller passes a PIL
                image).
            results: List that receives one dict per detected box. It is
                mutated in place and also returned.

        Returns:
            The same ``results`` list. It is returned unchanged when
            non-max suppression yields ``None`` for this image.

        Note:
            Category ids are looked up in the module-level ``clsid2catid``
            sequence, which this file assigns in its ``__main__`` section.
        """
        # Height and width of the input image, taken before any conversion.
        original_hw = np.array(np.shape(image)[0:2])

        # Convert to RGB so that e.g. grayscale inputs do not fail at
        # prediction time; only RGB images are supported by the pipeline.
        rgb_image = cvtColor(image)

        # Resize to the network input size. Per the original notes, this
        # either pads with gray bars (undistorted resize) or resizes
        # directly, depending on self.letterbox_image.
        target_size = (self.input_shape[1], self.input_shape[0])
        resized = resize_image(rgb_image, target_size, self.letterbox_image)

        # Preprocess, move axis 2 to the front, and add the batch dimension.
        pixels = preprocess_input(np.array(resized, dtype='float32'))
        batch = np.expand_dims(np.transpose(pixels, (2, 0, 1)), 0)

        with torch.no_grad():
            tensor = torch.from_numpy(batch)
            if self.cuda:
                tensor = tensor.cuda()

            # Feed the image through the network and decode the raw outputs.
            decoded = self.bbox_util.decode_box(self.net(tensor))

            # Stack the predicted boxes, then apply non-max suppression.
            detections = self.bbox_util.non_max_suppression(
                torch.cat(decoded, 1), self.num_classes, self.input_shape,
                original_hw, self.letterbox_image,
                conf_thres=self.confidence, nms_thres=self.nms_iou)

        kept = detections[0]
        if kept is None:
            return results

        # Column 6 is used as the class label, columns 4 and 5 are
        # multiplied into the score, and the first four columns are the box.
        class_indices = np.array(kept[:, 6], dtype='int32')
        scores = kept[:, 4] * kept[:, 5]
        boxes = kept[:, :4]

        for class_index, score, box in zip(class_indices, scores, boxes):
            top, left, bottom, right = box
            # The stored bbox is [left, top, width, height].
            results.append({
                "image_id": int(image_id),
                "category_id": clsid2catid[class_index],
                "bbox": [float(left), float(top), float(right - left), float(bottom - top)],
                "score": float(score),
            })
        return results
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: depending on map_mode, run the detector over the
    # validation images and dump the detections to JSON, and/or evaluate that
    # JSON against the ground truth with COCOeval.

    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the folder appeared between the check and the makedirs call.
    os.makedirs(temp_save_path, exist_ok=True)
    results_path = os.path.join(temp_save_path, 'eval_results.json')

    cocoGt = COCO(cocoGt_path)
    ids = list(cocoGt.imgToAnns.keys())
    # Module-level on purpose: mAP_YOLO.detect_image reads this name to turn
    # a predicted class index into a category id.
    clsid2catid = cocoGt.getCatIds()

    if map_mode in (0, 1):
        yolo = mAP_YOLO(confidence=0.001, nms_iou=0.65)

        results = []
        for image_id in tqdm(ids):
            file_name = cocoGt.loadImgs(image_id)[0]['file_name']
            image_path = os.path.join(dataset_img_path, file_name)
            # The context manager releases the image file as soon as the
            # detections for it have been collected.
            with Image.open(image_path) as image:
                results = yolo.detect_image(image_id, image, results)

        # Write only after every image has been processed, so a failure
        # during inference cannot leave a truncated or empty results file.
        with open(results_path, "w") as f:
            json.dump(results, f)

    if map_mode in (0, 2):
        cocoDt = cocoGt.loadRes(results_path)
        cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        print("Get map done.")
|