39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
import torch
|
|
from ultralytics import YOLO
|
|
from ultralytics.data import download
|
|
|
|
# 下载COCO128数据集
|
|
download('coco128')
|
|
|
|
# 定义训练参数
|
|
epochs = 10 # 训练轮数
|
|
batch_size = 16 # 批次大小
|
|
img_size = 640 # 输入图像尺寸
|
|
|
|
# 加载YOLOv8模型
|
|
model = YOLO('yolov8s.yaml') # 创建新的模型实例
|
|
|
|
# 开始训练
|
|
model.train(data='coco128.yaml', epochs=epochs, batch=batch_size, imgsz=img_size)
|
|
|
|
# 加载经过训练的模型,假设模型保存在 'best.pt'
|
|
model = YOLO('best.pt')
|
|
|
|
# 设置要检测的对象类别,这里的例子是只检测行人
|
|
class_names = model.names
|
|
person_class_id = class_names.index('person')
|
|
|
|
# 加载图片或视频
|
|
img_path = 'path_to_your_image.jpg'
|
|
|
|
# 进行目标检测
|
|
results = model(img_path)
|
|
|
|
# 处理结果
|
|
for result in results:
|
|
boxes = result.boxes
|
|
for box in boxes:
|
|
if box.cls == person_class_id: # 只处理行人检测结果
|
|
x1, y1, x2, y2 = box.xyxy[0] # 获取边界框坐标
|
|
confidence = box.conf.item() # 获取置信度
|
|
print(f"Pedestrian detected at ({x1:.2f}, {y1:.2f}) to ({x2:.2f}, {y2:.2f}), Confidence: {confidence:.2f}") |