GeoYolo-SLAM/ultralytics/eval.py

68 lines
2.5 KiB
Python
Raw Permalink Normal View History

2025-04-09 16:18:27 +08:00
import os
import numpy as np
def parse_yolo_label(file_path):
"""解析YOLO格式的标签文件。"""
labels = []
try:
with open(file_path, 'r') as file:
lines = file.readlines()
labels = [list(map(float, line.strip().split())) for line in lines]
except FileNotFoundError:
pass # 如果文件不存在,返回空列表
return labels
def calculate_iou(box1, box2):
"""计算两个边界框的交并比IoU"""
x1, y1, w1, h1 = box1
x2, y2, w2, h2 = box2
inter_w = max(0, min(x1 + w1 / 2, x2 + w2 / 2) - max(x1 - w1 / 2, x2 - w2 / 2))
inter_h = max(0, min(y1 + h1 / 2, y2 + h2 / 2) - max(y1 - h1 / 2, y2 - h2 / 2))
inter_area = inter_w * inter_h
union_area = w1 * h1 + w2 * h2 - inter_area
iou = inter_area / union_area if union_area > 0 else 0
return iou
def evaluate_predictions(gt_folder, pred_folder, iou_threshold=0.9):
"""评估预测的精度、召回率和F1分数。"""
gt_files = os.listdir(gt_folder)
tp = 0 # 真正例
fp = 0 # 假正例
fn = 0 # 假负例
for gt_file in gt_files:
gt_path = os.path.join(gt_folder, gt_file)
pred_path = os.path.join(pred_folder, gt_file)
gt_labels = parse_yolo_label(gt_path)
pred_labels = parse_yolo_label(pred_path)
matched = [False] * len(pred_labels)
for gt in gt_labels:
gt_matched = False
for i, pred in enumerate(pred_labels):
if calculate_iou(gt[1:], pred[1:]) >= iou_threshold:
if not matched[i]: # 防止多个真实标签匹配到同一个预测标签
matched[i] = True
gt_matched = True
tp += 1
break
if not gt_matched:
fn += 1
fp += matched.count(False) # 所有未匹配的预测都视为假正例
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return precision, recall, f1
# 示例用法
gt_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders2'
pred_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders'
precision, recall, f1 = evaluate_predictions(gt_folder, pred_folder)
print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1}')