241 lines
6.8 KiB
C++
241 lines
6.8 KiB
C++
#include "yolo.h"
|
||
#include "kalmanFilter.h"
|
||
#include <iostream>
|
||
#include <opencv2/opencv.hpp>
|
||
#include <vector>
|
||
#include <time.h>
|
||
#include "BYTETracker.h"
|
||
using namespace std;
|
||
using namespace cv;
|
||
using namespace cv::dnn;
|
||
|
||
int main()
|
||
{
|
||
string img_path = "D:/VS/yoloTest/images/detect.mp4";
|
||
string model_path = "D:/VS/yoloTest/models/yolov5s.onnx";
|
||
|
||
// 创建Yolov5模型和跟踪器
|
||
Yolov5 yolo;
|
||
Net net;
|
||
BYTETracker tracker(30, 30); // 创建Sort跟踪器
|
||
|
||
// 读取Yolo模型
|
||
if (!yolo.readModel(net, model_path, true)) {
|
||
cout << "无法加载Yolo模型" << endl;
|
||
return -1;
|
||
}
|
||
|
||
// 生成颜色表,每个类别对应一个颜色
|
||
vector<Scalar> colors;
|
||
srand(42); // 固定随机种子,确保每次运行颜色一致
|
||
for (int i = 0; i < 80; i++) { // 假设最多80个类别
|
||
int b = rand() % 256;
|
||
int g = rand() % 256;
|
||
int r = rand() % 256;
|
||
colors.push_back(Scalar(b, g, r));
|
||
}
|
||
|
||
// 读取视频
|
||
VideoCapture cap(img_path);
|
||
if (!cap.isOpened()) {
|
||
cout << "无法打开视频文件" << endl;
|
||
return -1;
|
||
}
|
||
|
||
VideoWriter writer;
|
||
writer = VideoWriter("D:/VS/yoloTest/images/test.mp4", CAP_OPENCV_MJPEG, 30, Size(1920, 1080), true);
|
||
|
||
clock_t start, end; //计时器
|
||
start = clock();
|
||
int num = 0;
|
||
|
||
Mat frame;
|
||
int frame_counter = 0;
|
||
int detect_interval = 3; // 每隔多少帧重新调用一次YOLO进行检测
|
||
vector<Rect> detection_rects;
|
||
vector<Object> objects;
|
||
|
||
while (cap.read(frame)) {
|
||
if (frame.empty()) {
|
||
cout << "视频已结束" << endl;
|
||
break;
|
||
}
|
||
|
||
vector<Output> result;
|
||
|
||
// 每隔 detect_interval 帧进行一次 YOLO 检测
|
||
if (frame_counter % detect_interval == 0) {
|
||
detection_rects.clear(); // 清空之前的检测框
|
||
objects.clear();
|
||
|
||
// 进行目标检测
|
||
if (yolo.Detect(frame, net, result)) {
|
||
for (int i = 0; i < result.size(); i++) {
|
||
int x = result[i].box.x;
|
||
int y = result[i].box.y;
|
||
int w = result[i].box.width;
|
||
int h = result[i].box.height;
|
||
Rect rect(x, y, w, h);
|
||
Object object;
|
||
object.label = result[i].id;
|
||
object.prob = result[i].confidence;
|
||
object.rect = rect;
|
||
objects.push_back(object);
|
||
}
|
||
}
|
||
|
||
// 将检测结果传递给 SORT 进行初始化或更新
|
||
vector<STrack> object_trackers = tracker.update(objects);
|
||
// 绘制跟踪结果
|
||
vector<vector<float>> trackers;
|
||
for (int i = 0; i < object_trackers.size(); i++) {
|
||
trackers.push_back(object_trackers[i].tlwh);
|
||
}
|
||
|
||
|
||
for (int i = 0; i < trackers.size(); i++) {
|
||
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2], trackers[i][3]);
|
||
|
||
// 获取检测的类别和置信度
|
||
// 在这里,我们假设检测的类别和置信度没有变化
|
||
int class_id = 0; // 假设分类 ID 为 0,仅用于示例
|
||
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
|
||
|
||
// 使用类别索引从颜色表中获取颜色
|
||
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
|
||
|
||
// 构建标签内容:类别名称 + 置信度
|
||
string label = "box";
|
||
|
||
// 绘制矩形框和类别标签
|
||
rectangle(frame, rect, color, 2); // 绘制矩形框
|
||
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
|
||
}
|
||
}
|
||
else {
|
||
// 如果不是检测帧,使用 SORT 进行目标预测和跟踪
|
||
vector<STrack> object_trackers = tracker.update(objects);
|
||
// 绘制跟踪结果
|
||
vector<vector<float>> trackers;
|
||
for (int i = 0; i < object_trackers.size(); i++) {
|
||
trackers.push_back(object_trackers[i].tlwh);
|
||
}
|
||
|
||
// 绘制跟踪结果
|
||
for (int i = 0; i < trackers.size(); i++) {
|
||
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2], trackers[i][3]);
|
||
|
||
// 获取检测的类别和置信度
|
||
// 在这里,我们假设检测的类别和置信度没有变化
|
||
int class_id = 0; // 假设分类 ID 为 0,仅用于示例
|
||
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
|
||
|
||
// 使用类别索引从颜色表中获取颜色
|
||
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
|
||
|
||
// 构建标签内容:类别名称 + 置信度
|
||
string label = "box";
|
||
|
||
// 绘制矩形框和类别标签
|
||
rectangle(frame, rect, color, 2); // 绘制矩形框
|
||
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
|
||
}
|
||
}
|
||
|
||
// 显示并保存帧
|
||
imshow("img", frame);
|
||
writer << frame;
|
||
|
||
if (waitKey(10) == 27) {
|
||
break; // 退出循环
|
||
}
|
||
|
||
frame_counter++;
|
||
num++;
|
||
}
|
||
|
||
end = clock();
|
||
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; // 输出时间(单位:s)
|
||
cout << "frame num: " << num << endl;
|
||
cout << "Speed = " << double(end - start) * 1000 / CLOCKS_PER_SEC / num << "ms" << endl; // 输出时间(单位:ms)
|
||
|
||
cap.release();
|
||
destroyAllWindows();
|
||
writer.release();
|
||
|
||
return 0;
|
||
}
|
||
|
||
/*
|
||
#include "yolo.h"
|
||
#include <iostream>
|
||
#include<opencv2//opencv.hpp>
|
||
#include<math.h>
|
||
#include <time.h>
|
||
|
||
using namespace std;
|
||
using namespace cv;
|
||
using namespace dnn;
|
||
|
||
int main()
|
||
{
|
||
string img_path = "D:/VS/yoloTest/images/2.mp4";
|
||
string model = "D:/VS/yoloTest/models/yolov5s.onnx";
|
||
|
||
Yolov5 test;
|
||
Net net;
|
||
if (test.readModel(net, model, false)) {
|
||
cout << "read net ok!" << endl;
|
||
}
|
||
else {
|
||
return -1;
|
||
}
|
||
|
||
//生成随机颜色
|
||
vector<Scalar> color;
|
||
srand(time(0));
|
||
for (int i = 0; i < 80; i++) {
|
||
int b = rand() % 256;
|
||
int g = rand() % 256;
|
||
int r = rand() % 256;
|
||
color.push_back(Scalar(b, g, r));
|
||
}
|
||
vector<Output> result;
|
||
|
||
VideoCapture cap(img_path); //读取图像/视频
|
||
|
||
VideoWriter writer;
|
||
writer = VideoWriter("D:/VS/yoloTest/images/test.mp4", CAP_OPENCV_MJPEG, 20, Size(2560, 1440), true); //将结果保存为视频
|
||
|
||
clock_t start, end; //计时器
|
||
start = clock();
|
||
int num = 0;
|
||
|
||
Mat frame;
|
||
while (cap.read(frame)) {
|
||
|
||
vector<Output> result;
|
||
if (test.Detect(frame, net, result)) {
|
||
test.drawPred(frame, result, color);
|
||
}
|
||
|
||
imshow("img", frame);
|
||
writer << frame;
|
||
if (waitKey(10) == 27) {
|
||
break; // 退出循环
|
||
}
|
||
num++;
|
||
}
|
||
end = clock();
|
||
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; //输出时间(单位:s)
|
||
cout << "frame num: " << num << endl;
|
||
cout << "Speed = " << double(end - start)*1000 / CLOCKS_PER_SEC/num << "ms" << endl; //输出时间(单位:ms)
|
||
|
||
cap.release();
|
||
destroyAllWindows();
|
||
|
||
system("pause");
|
||
return 0;
|
||
}
|
||
|
||
*/ |