GeoYolo-SLAM/ultralytics/predict.py

46 lines
1.9 KiB
Python
Raw Permalink Normal View History

2025-04-09 16:18:27 +08:00
import os
from ultralytics import YOLO
def predict_and_save(model, input_folder, output_folder):
# 创建输出文件夹(如果不存在)
os.makedirs(output_folder, exist_ok=True)
# 获取输入文件夹中的所有图像文件
image_files = sorted([f for f in os.listdir(input_folder) if f.endswith(('.jpg', '.jpeg', '.png'))])
# 预测每张图像并保存结果
for image_file in image_files:
# 使用图片文件的编号
file_name = os.path.splitext(image_file)[0]
output_file = os.path.join(output_folder, f'{file_name}.txt')
# 加载图像进行预测
image_path = os.path.join(input_folder, image_file)
results = model(image_path)
# 打开文件并以写入模式保存预测结果到标签文件
with open(output_file, 'w') as f:
for result in results:
for bbox in result.boxes:
print(f"bbox.xywh: {bbox.xywh}")
if bbox.xywh.shape[1] == 4: # 确认 bbox.xywh 形状正确
cls = int(bbox.cls[0])
x_center = float(bbox.xywh[0, 0])
y_center = float(bbox.xywh[0, 1])
width = float(bbox.xywh[0, 2])
height = float(bbox.xywh[0, 3])
f.write(f'{cls} {x_center} {y_center} {width} {height}\n')
# 确保模型路径正确并且文件存在
model_path = os.path.expanduser('~/catkin_ws/src/ultralytics/yolov8n.pt')
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found at {model_path}")
# 加载模型
model = YOLO(model_path) # 替换为你的YOLOv8模型路径
input_folder = '/root/catkin_ws/src/ultralytics/ours_15000/renders' # 确保这是图像文件夹
output_folder = '/root/catkin_ws/src/ultralytics/ours_15000/labels_renders2'
predict_and_save(model, input_folder, output_folder)