fujie_code/predict.py

182 lines
9.0 KiB
Python
Raw Normal View History

2024-07-04 17:03:29 +08:00
# -----------------------------------------------------------------------#
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
# 整合到了一个py文件中通过指定mode进行模式的修改。
# -----------------------------------------------------------------------#
import time
import cv2
import numpy as np
from PIL import Image
from yolo import YOLO
if __name__ == "__main__":
yolo = YOLO()
# ----------------------------------------------------------------------------------------------------------#
# mode用于指定测试的模式
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
# 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
# 'fps' 表示测试fps使用的图片是img里面的street.jpg详情查看下方注释。
# 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹保存img_out文件夹详情查看下方注释。
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
# 'export_onnx' 表示将模型导出为onnx需要pytorch1.7.1以上。
# ----------------------------------------------------------------------------------------------------------#
mode = "predict"
# -------------------------------------------------------------------------#
# crop 指定了是否在单张图片预测后对目标进行截取
# count 指定了是否进行目标的计数
# crop、count仅在mode='predict'时有效
# -------------------------------------------------------------------------#
crop = False
count = False
# ----------------------------------------------------------------------------------------------------------#
# video_path 用于指定视频的路径当video_path=0时表示检测摄像头
# 想要检测视频则设置如video_path = "xxx.mp4"即可代表读取出根目录下的xxx.mp4文件。
# video_save_path 表示视频保存的路径当video_save_path=""时表示不保存
# 想要保存视频则设置如video_save_path = "yyy.mp4"即可代表保存为根目录下的yyy.mp4文件。
# video_fps 用于保存的视频的fps
#
# video_path、video_save_path和video_fps仅在mode='video'时有效
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
# ----------------------------------------------------------------------------------------------------------#
video_path = 0
video_save_path = ""
video_fps = 25.0
# ----------------------------------------------------------------------------------------------------------#
# test_interval 用于指定测量fps的时候图片检测的次数。理论上test_interval越大fps越准确。
# fps_image_path 用于指定测试的fps图片
#
# test_interval和fps_image_path仅在mode='fps'有效
# ----------------------------------------------------------------------------------------------------------#
test_interval = 100
fps_image_path = "img/street.jpg"
# -------------------------------------------------------------------------#
# dir_origin_path 指定了用于检测的图片的文件夹路径
# dir_save_path 指定了检测完图片的保存路径
#
# dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
# -------------------------------------------------------------------------#
dir_origin_path = "img/"
dir_save_path = "img_out/"
# -------------------------------------------------------------------------#
# heatmap_save_path 热力图的保存路径默认保存在model_data下
#
# heatmap_save_path仅在mode='heatmap'有效
# -------------------------------------------------------------------------#
heatmap_save_path = "model_data/heatmap_vision.png"
# -------------------------------------------------------------------------#
# simplify 使用Simplify onnx
# onnx_save_path 指定了onnx的保存路径
# -------------------------------------------------------------------------#
simplify = True
onnx_save_path = "model_data/models.onnx"
if mode == "predict":
'''
1如果想要进行检测完的图片的保存利用r_image.save("img.jpg")即可保存直接在predict.py里进行修改即可
2如果想要获得预测框的坐标可以进入yolo.detect_image函数在绘图部分读取topleftbottomright这四个值
3如果想要利用预测框截取下目标可以进入yolo.detect_image函数在绘图部分利用获取到的topleftbottomright这四个值
在原图上利用矩阵的方式进行截取
4如果想要在预测图上写额外的字比如检测到的特定目标的数量可以进入yolo.detect_image函数在绘图部分对predicted_class进行判断
比如判断if predicted_class == 'car': 即可判断当前目标是否为车然后记录数量即可利用draw.text即可写字
'''
while True:
img = input('Input image filename:')
# img/street.jpg
# img/street_a3.jpg
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
r_image = yolo.detect_image(image, crop=crop, count=count)
# r_image.show()
r_image.save("duffision.png")
elif mode == "video":
capture = cv2.VideoCapture(video_path)
if video_save_path != "":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
ref, frame = capture.read()
if not ref:
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
fps = 0.0
while (True):
t1 = time.time()
# 读取某一帧
ref, frame = capture.read()
if not ref:
break
# 格式转变BGRtoRGB
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(yolo.detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
fps = (fps + (1. / (time.time() - t1))) / 2
print("fps= %.2f" % (fps))
frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video", frame)
c = cv2.waitKey(1) & 0xff
if video_save_path != "":
out.write(frame)
if c == 27:
capture.release()
break
print("Video Detection Done!")
capture.release()
if video_save_path != "":
print("Save processed video to the path :" + video_save_path)
out.release()
cv2.destroyAllWindows()
elif mode == "fps":
img = Image.open(fps_image_path)
tact_time = yolo.get_FPS(img, test_interval)
print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
elif mode == "dir_predict":
import os
from tqdm import tqdm
img_names = os.listdir(dir_origin_path)
for img_name in tqdm(img_names):
if img_name.lower().endswith(
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(dir_origin_path, img_name)
image = Image.open(image_path)
r_image = yolo.detect_image(image)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
elif mode == "heatmap":
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
yolo.detect_heatmap(image, heatmap_save_path)
elif mode == "export_onnx":
yolo.convert_to_onnx(simplify, onnx_save_path)
else:
raise AssertionError(
"Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")