68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
|
import os
|
|||
|
import numpy as np
|
|||
|
|
|||
|
def parse_yolo_label(file_path):
    """Parse a YOLO-format label file into a list of numeric rows.

    Every non-blank line is split on whitespace and each field is converted
    to ``float``. In YOLO format a row is ``class cx cy w h``, optionally
    followed by extra columns (e.g. a confidence score).

    Args:
        file_path: Path to the label ``.txt`` file.

    Returns:
        A list with one ``list[float]`` per non-blank line. A missing file
        yields an empty list (YOLO convention: no label file means the image
        has no objects).

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            # Blank lines are skipped: they would otherwise become empty rows
            # and break callers that unpack the box coordinates.
            return [
                [float(field) for field in fields]
                for fields in (line.split() for line in file)
                if fields
            ]
    except FileNotFoundError:
        # Missing file -> no labels, return an empty list.
        return []
|
|||
|
|
|||
|
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Both boxes are 4-element sequences ``(cx, cy, w, h)``: the box centre
    followed by its width and height, in any consistent unit (presumably
    YOLO's normalised coordinates here).

    Returns:
        ``intersection / union``, or ``0`` when the union area is not
        positive (e.g. two zero-size boxes).
    """
    cx_a, cy_a, w_a, h_a = box1
    cx_b, cy_b, w_b, h_b = box2

    # Edges of the overlap rectangle; it is empty when right < left or bottom < top.
    left = max(cx_a - w_a / 2, cx_b - w_b / 2)
    right = min(cx_a + w_a / 2, cx_b + w_b / 2)
    top = max(cy_a - h_a / 2, cy_b - h_b / 2)
    bottom = min(cy_a + h_a / 2, cy_b + h_b / 2)

    overlap = max(0, right - left) * max(0, bottom - top)
    union = w_a * h_a + w_b * h_b - overlap

    if union > 0:
        return overlap / union
    return 0
|
|||
|
|
|||
|
def _label_files(folder):
    """Return the names of regular ``.txt`` files directly inside *folder*."""
    return {
        name
        for name in os.listdir(folder)
        if name.endswith('.txt') and os.path.isfile(os.path.join(folder, name))
    }


def _match_labels(gt_labels, pred_labels, iou_threshold, match_class):
    """Greedily match one image's predictions to its ground truth; return ``(tp, fp, fn)``."""
    matched = [False] * len(pred_labels)
    tp = fn = 0
    for gt in gt_labels:
        for i, pred in enumerate(pred_labels):
            if matched[i]:
                continue  # each prediction may match at most one ground-truth box
            if match_class and pred[0] != gt[0]:
                continue
            # Columns 1-4 are (cx, cy, w, h); trailing columns such as a
            # confidence score are ignored.
            if calculate_iou(gt[1:5], pred[1:5]) >= iou_threshold:
                matched[i] = True
                tp += 1
                break
        else:
            # No unmatched prediction overlapped this ground-truth box enough.
            fn += 1
    # Every prediction left unmatched is a false positive.
    fp = matched.count(False)
    return tp, fp, fn


def evaluate_predictions(gt_folder, pred_folder, iou_threshold=0.9, *, match_class=False):
    """Compute detection precision, recall and F1 over two YOLO label folders.

    Label files are paired by filename. A file present in only one folder is
    treated as having no boxes in the other. Ground-truth boxes on an image
    without a prediction file therefore count as false negatives, and
    predictions on an image without a ground-truth file count as false
    positives.

    Args:
        gt_folder: Directory containing ground-truth ``.txt`` label files.
        pred_folder: Directory containing predicted ``.txt`` label files. If
            it does not exist, it is treated as containing no predictions.
        iou_threshold: Minimum IoU for a prediction to match a ground-truth
            box. NOTE(review): 0.9 is far stricter than the common 0.5;
            confirm it is intended.
        match_class: If True, a prediction only matches a ground-truth box
            with the same class id (column 0). The default False keeps the
            original class-agnostic matching.

    Returns:
        ``(precision, recall, f1)``. Each value is 0 when its denominator is 0.

    Raises:
        FileNotFoundError: If ``gt_folder`` does not exist.
    """
    names = _label_files(gt_folder)
    if os.path.isdir(pred_folder):
        names |= _label_files(pred_folder)

    tp = fp = fn = 0
    for name in sorted(names):
        gt_labels = parse_yolo_label(os.path.join(gt_folder, name))
        pred_labels = parse_yolo_label(os.path.join(pred_folder, name))
        img_tp, img_fp, img_fn = _match_labels(
            gt_labels, pred_labels, iou_threshold, match_class
        )
        tp += img_tp
        fp += img_fp
        fn += img_fn

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1
|
|||
|
|
|||
|
# Example usage: evaluate the predicted labels against the reference labels.
# NOTE(review): these absolute paths are machine-specific; adjust before running.
gt_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders2'
pred_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders'

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    precision, recall, f1 = evaluate_predictions(gt_folder, pred_folder)
    print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1}')
|