Yolo-Detection/yolo+ByteTrack/yoloTest/main.cpp

241 lines
6.8 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "yolo.h"
#include "kalmanFilter.h"
#include <iostream>
#include <opencv2/opencv.hpp>
#include <vector>
#include <time.h>
#include "BYTETracker.h"
using namespace std;
using namespace cv;
using namespace cv::dnn;
int main()
{
string img_path = "D:/VS/yoloTest/images/detect.mp4";
string model_path = "D:/VS/yoloTest/models/yolov5s.onnx";
// 创建Yolov5模型和跟踪器
Yolov5 yolo;
Net net;
BYTETracker tracker(30, 30); // 创建Sort跟踪器
// 读取Yolo模型
if (!yolo.readModel(net, model_path, true)) {
cout << "无法加载Yolo模型" << endl;
return -1;
}
// 生成颜色表,每个类别对应一个颜色
vector<Scalar> colors;
srand(42); // 固定随机种子,确保每次运行颜色一致
for (int i = 0; i < 80; i++) { // 假设最多80个类别
int b = rand() % 256;
int g = rand() % 256;
int r = rand() % 256;
colors.push_back(Scalar(b, g, r));
}
// 读取视频
VideoCapture cap(img_path);
if (!cap.isOpened()) {
cout << "无法打开视频文件" << endl;
return -1;
}
VideoWriter writer;
writer = VideoWriter("D:/VS/yoloTest/images/test.mp4", CAP_OPENCV_MJPEG, 30, Size(1920, 1080), true);
clock_t start, end; //计时器
start = clock();
int num = 0;
Mat frame;
int frame_counter = 0;
int detect_interval = 3; // 每隔多少帧重新调用一次YOLO进行检测
vector<Rect> detection_rects;
vector<Object> objects;
while (cap.read(frame)) {
if (frame.empty()) {
cout << "视频已结束" << endl;
break;
}
vector<Output> result;
// 每隔 detect_interval 帧进行一次 YOLO 检测
if (frame_counter % detect_interval == 0) {
detection_rects.clear(); // 清空之前的检测框
objects.clear();
// 进行目标检测
if (yolo.Detect(frame, net, result)) {
for (int i = 0; i < result.size(); i++) {
int x = result[i].box.x;
int y = result[i].box.y;
int w = result[i].box.width;
int h = result[i].box.height;
Rect rect(x, y, w, h);
Object object;
object.label = result[i].id;
object.prob = result[i].confidence;
object.rect = rect;
objects.push_back(object);
}
}
// 将检测结果传递给 SORT 进行初始化或更新
vector<STrack> object_trackers = tracker.update(objects);
// 绘制跟踪结果
vector<vector<float>> trackers;
for (int i = 0; i < object_trackers.size(); i++) {
trackers.push_back(object_trackers[i].tlwh);
}
for (int i = 0; i < trackers.size(); i++) {
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2], trackers[i][3]);
// 获取检测的类别和置信度
// 在这里,我们假设检测的类别和置信度没有变化
int class_id = 0; // 假设分类 ID 为 0仅用于示例
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
// 使用类别索引从颜色表中获取颜色
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
// 构建标签内容:类别名称 + 置信度
string label = "box";
// 绘制矩形框和类别标签
rectangle(frame, rect, color, 2); // 绘制矩形框
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
}
}
else {
// 如果不是检测帧,使用 SORT 进行目标预测和跟踪
vector<STrack> object_trackers = tracker.update(objects);
// 绘制跟踪结果
vector<vector<float>> trackers;
for (int i = 0; i < object_trackers.size(); i++) {
trackers.push_back(object_trackers[i].tlwh);
}
// 绘制跟踪结果
for (int i = 0; i < trackers.size(); i++) {
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2], trackers[i][3]);
// 获取检测的类别和置信度
// 在这里,我们假设检测的类别和置信度没有变化
int class_id = 0; // 假设分类 ID 为 0仅用于示例
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
// 使用类别索引从颜色表中获取颜色
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
// 构建标签内容:类别名称 + 置信度
string label = "box";
// 绘制矩形框和类别标签
rectangle(frame, rect, color, 2); // 绘制矩形框
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
}
}
// 显示并保存帧
imshow("img", frame);
writer << frame;
if (waitKey(10) == 27) {
break; // 退出循环
}
frame_counter++;
num++;
}
end = clock();
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; // 输出时间单位s
cout << "frame num: " << num << endl;
cout << "Speed = " << double(end - start) * 1000 / CLOCKS_PER_SEC / num << "ms" << endl; // 输出时间单位ms
cap.release();
destroyAllWindows();
writer.release();
return 0;
}
/*
#include "yolo.h"
#include <iostream>
#include<opencv2//opencv.hpp>
#include<math.h>
#include <time.h>
using namespace std;
using namespace cv;
using namespace dnn;
int main()
{
string img_path = "D:/VS/yoloTest/images/2.mp4";
string model = "D:/VS/yoloTest/models/yolov5s.onnx";
Yolov5 test;
Net net;
if (test.readModel(net, model, false)) {
cout << "read net ok!" << endl;
}
else {
return -1;
}
//生成随机颜色
vector<Scalar> color;
srand(time(0));
for (int i = 0; i < 80; i++) {
int b = rand() % 256;
int g = rand() % 256;
int r = rand() % 256;
color.push_back(Scalar(b, g, r));
}
vector<Output> result;
VideoCapture cap(img_path); //读取图像/视频
VideoWriter writer;
writer = VideoWriter("D:/VS/yoloTest/images/test.mp4", CAP_OPENCV_MJPEG, 20, Size(2560, 1440), true); //将结果保存为视频
clock_t start, end; //计时器
start = clock();
int num = 0;
Mat frame;
while (cap.read(frame)) {
vector<Output> result;
if (test.Detect(frame, net, result)) {
test.drawPred(frame, result, color);
}
imshow("img", frame);
writer << frame;
if (waitKey(10) == 27) {
break; // 退出循环
}
num++;
}
end = clock();
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; //输出时间(单位:s)
cout << "frame num: " << num << endl;
cout << "Speed = " << double(end - start)*1000 / CLOCKS_PER_SEC/num << "ms" << endl; //输出时间单位m
cap.release();
destroyAllWindows();
system("pause");
return 0;
}
*/