68 lines
2.5 KiB
Python
Executable File
68 lines
2.5 KiB
Python
Executable File
import os
|
||
import numpy as np
|
||
|
||
def parse_yolo_label(file_path):
    """Parse a YOLO-format label file.

    Each non-blank line is split on whitespace and converted to a list of
    floats, e.g. ``"0 0.5 0.5 0.2 0.3"`` -> ``[0.0, 0.5, 0.5, 0.2, 0.3]``.

    Args:
        file_path: Path of the label file to read.

    Returns:
        A list with one list of floats per non-blank line. A missing file is
        treated as having no labels and yields an empty list.

    Raises:
        ValueError: If a field on some line is not a valid float.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            # Blank lines (e.g. a trailing empty line) are skipped; otherwise
            # they would become empty rows that cannot be unpacked into box
            # coordinates downstream. str.split() with no argument also strips.
            return [
                [float(field) for field in fields]
                for fields in (line.split() for line in file)
                if fields
            ]
    except FileNotFoundError:
        # Deliberate: a missing label file means "no labels".
        return []
|
||
|
||
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two boxes.

    Each box is a 4-sequence ``(x, y, w, h)`` where ``(x, y)`` is the box
    centre, so the box spans ``x - w/2 .. x + w/2`` and ``y - h/2 .. y + h/2``.

    Returns:
        The intersection area divided by the union area, or 0 when the union
        area is not positive.
    """
    cx_a, cy_a, wid_a, hgt_a = box1
    cx_b, cy_b, wid_b, hgt_b = box2

    def _overlap(centre_a, size_a, centre_b, size_b):
        # Length of the overlap of two 1-D intervals given as centre/size.
        lower = max(centre_a - size_a / 2, centre_b - size_b / 2)
        upper = min(centre_a + size_a / 2, centre_b + size_b / 2)
        return max(0, upper - lower)

    overlap = _overlap(cx_a, wid_a, cx_b, wid_b) * _overlap(cy_a, hgt_a, cy_b, hgt_b)
    combined = wid_a * hgt_a + wid_b * hgt_b - overlap

    if combined > 0:
        return overlap / combined
    return 0
|
||
|
||
def _match_counts(gt_labels, pred_labels, iou_threshold, match_class):
    """Match one image's ground-truth boxes to its predictions; return (tp, fp, fn).

    Each ground-truth box is matched to the still-unmatched prediction with the
    highest IoU, provided that IoU is at least ``iou_threshold`` (and, when
    ``match_class`` is true, the class ids in column 0 agree). Every prediction
    is used at most once.
    """
    matched = [False] * len(pred_labels)
    tp = 0
    for gt in gt_labels:
        best_index = None
        best_iou = None
        for index, pred in enumerate(pred_labels):
            if matched[index]:
                continue  # a prediction may match at most one ground-truth box
            if match_class and int(gt[0]) != int(pred[0]):
                continue
            # Columns 1-4 are (x, y, w, h); slicing to 5 ignores any extra
            # trailing columns (e.g. a confidence score, if present).
            iou = calculate_iou(gt[1:5], pred[1:5])
            if iou >= iou_threshold and (best_iou is None or iou > best_iou):
                best_index, best_iou = index, iou
        if best_index is not None:
            matched[best_index] = True
            tp += 1
    fn = len(gt_labels) - tp            # ground-truth boxes left unmatched
    fp = matched.count(False)           # predictions left unmatched
    return tp, fp, fn


def evaluate_predictions(gt_folder, pred_folder, iou_threshold=0.9, match_class=False):
    """Compute precision, recall and F1 of YOLO predictions against ground truth.

    Label files are paired by file name. Every ``.txt`` file present in either
    folder is evaluated. Predictions for an image that has no ground-truth
    file therefore count as false positives. A ground-truth file without a
    matching prediction file counts all of its boxes as false negatives.

    Args:
        gt_folder: Directory of ground-truth YOLO label files.
        pred_folder: Directory of predicted YOLO label files. If it does not
            exist, every ground-truth box counts as a false negative.
        iou_threshold: Minimum IoU for a prediction to match a ground-truth
            box. NOTE(review): 0.9 is much stricter than the commonly used
            0.5; kept as the default for backward compatibility.
        match_class: If true, a prediction must also have the same class id
            (first column) as the ground-truth box. The default False matches
            on IoU alone, as before.

    Returns:
        ``(precision, recall, f1)``; each is 0 when its denominator is 0.
    """
    label_names = {name for name in os.listdir(gt_folder) if name.endswith('.txt')}
    if os.path.isdir(pred_folder):
        label_names.update(
            name for name in os.listdir(pred_folder) if name.endswith('.txt')
        )

    tp = fp = fn = 0
    for name in sorted(label_names):
        # Empty rows (from blank lines) carry no box and are dropped before matching.
        gt_labels = [row for row in parse_yolo_label(os.path.join(gt_folder, name)) if row]
        pred_labels = [row for row in parse_yolo_label(os.path.join(pred_folder, name)) if row]

        file_tp, file_fp, file_fn = _match_counts(
            gt_labels, pred_labels, iou_threshold, match_class
        )
        tp += file_tp
        fp += file_fp
        fn += file_fn

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1
|
||
|
||
# Example usage: evaluate the predictions against the ground truth, but only
# when this file is run as a script (not as an import side effect).
if __name__ == '__main__':
    gt_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders2'
    pred_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders'

    precision, recall, f1 = evaluate_predictions(gt_folder, pred_folder)

    print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1}')