266 lines
8.1 KiB
C++
266 lines
8.1 KiB
C++
#include "yolo.h"
|
||
#include "kalmanboxtracker.h"
|
||
#include <iostream>
|
||
#include <opencv2/opencv.hpp>
|
||
#include <vector>
|
||
#include <time.h>
|
||
#include "sort.h"
|
||
using namespace std;
|
||
using namespace cv;
|
||
using namespace cv::dnn;
|
||
|
||
int main()
|
||
{
|
||
string img_path = "D:/VS/yoloSORT/images/test.mp4";
|
||
string model_path = "D:/VS/yoloSORT/models/yolov5s.onnx";
|
||
|
||
// 创建Yolov5模型和跟踪器
|
||
Yolov5 yolo;
|
||
Net net;
|
||
Sort mot_tracker = Sort(1, 3, 0.3); // 创建Sort跟踪器
|
||
|
||
// 读取Yolo模型
|
||
if (!yolo.readModel(net, model_path, true)) {
|
||
cout << "无法加载Yolo模型" << endl;
|
||
return -1;
|
||
}
|
||
|
||
// 生成颜色表,每个类别对应一个颜色
|
||
vector<Scalar> colors;
|
||
srand(42); // 固定随机种子,确保每次运行颜色一致
|
||
for (int i = 0; i < 80; i++) { // 假设最多80个类别
|
||
int b = rand() % 256;
|
||
int g = rand() % 256;
|
||
int r = rand() % 256;
|
||
colors.push_back(Scalar(b, g, r));
|
||
}
|
||
|
||
// 读取视频
|
||
VideoCapture cap(img_path);
|
||
if (!cap.isOpened()) {
|
||
cout << "无法打开视频文件" << endl;
|
||
return -1;
|
||
}
|
||
|
||
VideoWriter writer;
|
||
writer = VideoWriter("D:/VS/yoloSORT/images/detect.mp4", CAP_OPENCV_MJPEG, 20, Size(2560, 1440), true);
|
||
|
||
clock_t start, end; //计时器
|
||
start = clock();
|
||
int num = 0;
|
||
|
||
Mat frame;
|
||
int frame_counter = 0;
|
||
int detect_interval = 3; // 每隔多少帧重新调用一次YOLO进行检测
|
||
vector<Rect> detection_rects;
|
||
|
||
while (cap.read(frame)) {
|
||
cap >> frame;
|
||
|
||
if (frame.empty()) {
|
||
cout << "视频已结束" << endl;
|
||
break;
|
||
}
|
||
|
||
vector<Output> result;
|
||
|
||
// 每隔 detect_interval 帧进行一次 YOLO 检测
|
||
if (frame_counter % detect_interval == 0) {
|
||
detection_rects.clear(); // 清空之前的检测框
|
||
|
||
// 进行目标检测
|
||
if (yolo.Detect(frame, net, result)) {
|
||
for (int i = 0; i < result.size(); i++) {
|
||
int x = result[i].box.x;
|
||
int y = result[i].box.y;
|
||
int w = result[i].box.width;
|
||
int h = result[i].box.height;
|
||
Rect rect(x, y, w, h);
|
||
detection_rects.push_back(rect);
|
||
}
|
||
}
|
||
|
||
// 将检测结果传递给 SORT 进行初始化或更新
|
||
vector<vector<float>> trackers = mot_tracker.update(detection_rects);
|
||
// 绘制跟踪结果
|
||
for (int i = 0; i < trackers.size(); i++) {
|
||
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2] - trackers[i][0], trackers[i][3] - trackers[i][1]);
|
||
|
||
// 获取检测的类别和置信度
|
||
// 在这里,我们假设检测的类别和置信度没有变化
|
||
int class_id = 0; // 假设分类 ID 为 0,仅用于示例
|
||
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
|
||
|
||
// 使用类别索引从颜色表中获取颜色
|
||
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
|
||
|
||
// 构建标签内容:类别名称 + 置信度
|
||
string label = "box";
|
||
|
||
// 绘制矩形框和类别标签
|
||
rectangle(frame, rect, color, 2); // 绘制矩形框
|
||
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
|
||
}
|
||
}
|
||
else {
|
||
// 如果不是检测帧,使用 SORT 进行目标预测和跟踪
|
||
vector<vector<float>> trackers = mot_tracker.update(detection_rects);
|
||
|
||
// 绘制跟踪结果
|
||
for (int i = 0; i < trackers.size(); i++) {
|
||
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2] - trackers[i][0], trackers[i][3] - trackers[i][1]);
|
||
|
||
// 获取检测的类别和置信度
|
||
// 在这里,我们假设检测的类别和置信度没有变化
|
||
int class_id = 0; // 假设分类 ID 为 0,仅用于示例
|
||
float confidence = 0.8f; // 假设置信度为 0.8,仅用于示例
|
||
|
||
// 使用类别索引从颜色表中获取颜色
|
||
Scalar color = colors[class_id % colors.size()]; // 使用颜色表中的颜色
|
||
|
||
// 构建标签内容:类别名称 + 置信度
|
||
string label = "box";
|
||
|
||
// 绘制矩形框和类别标签
|
||
rectangle(frame, rect, color, 2); // 绘制矩形框
|
||
putText(frame, label, Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, color, 2); // 显示类别名称和置信度
|
||
}
|
||
}
|
||
|
||
// 显示并保存帧
|
||
imshow("img", frame);
|
||
writer << frame;
|
||
|
||
if (waitKey(10) == 27) {
|
||
break; // 退出循环
|
||
}
|
||
|
||
frame_counter++;
|
||
num++;
|
||
}
|
||
|
||
end = clock();
|
||
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; // 输出时间(单位:s)
|
||
cout << "frame num: " << num << endl;
|
||
cout << "Speed = " << double(end - start) * 1000 / CLOCKS_PER_SEC / num << "ms" << endl; // 输出时间(单位:ms)
|
||
|
||
cap.release();
|
||
destroyAllWindows();
|
||
|
||
return 0;
|
||
}
|
||
|
||
|
||
/*
|
||
#include "yolo.h"
|
||
#include "kalmanboxtracker.h"
|
||
#include <iostream>
|
||
#include <opencv2/opencv.hpp>
|
||
#include <vector>
|
||
#include <time.h>
|
||
#include "sort.h"
|
||
using namespace std;
|
||
using namespace cv;
|
||
using namespace cv::dnn;
|
||
|
||
int main()
|
||
{
|
||
string img_path = "D:/VS/yoloTest/images/1.mp4";
|
||
string model_path = "D:/VS/yoloTest/models/yolov5s.onnx";
|
||
|
||
// 创建Yolov5模型和跟踪器
|
||
Yolov5 yolo;
|
||
Net net;
|
||
Sort mot_tracker = Sort(1, 3, 0.3); // 创建Sort跟踪器
|
||
|
||
// 读取Yolo模型
|
||
if (!yolo.readModel(net, model_path, true)) {
|
||
cout << "无法加载Yolo模型" << endl;
|
||
return -1;
|
||
}
|
||
|
||
// 生成随机颜色
|
||
srand(time(0));
|
||
vector<Scalar> colors;
|
||
for (int i = 0; i < 80; i++) {
|
||
int b = rand() % 256;
|
||
int g = rand() % 256;
|
||
int r = rand() % 256;
|
||
colors.push_back(Scalar(b, g, r));
|
||
}
|
||
|
||
// 读取视频
|
||
VideoCapture cap(img_path);
|
||
if (!cap.isOpened()) {
|
||
cout << "无法打开视频文件" << endl;
|
||
return -1;
|
||
}
|
||
|
||
VideoWriter writer;
|
||
writer = VideoWriter("D:/VS/yoloTest/images/test.mp4", CAP_OPENCV_MJPEG, 20, Size(2560, 1440), true);
|
||
|
||
clock_t start, end; //计时器
|
||
start = clock();
|
||
int num = 0;
|
||
|
||
Mat frame;
|
||
while (cap.read(frame)) {
|
||
cap >> frame;
|
||
|
||
if (frame.empty()) {
|
||
cout << "视频已结束" << endl;
|
||
break;
|
||
}
|
||
// 进行目标检测
|
||
vector<Output> result;
|
||
if (yolo.Detect(frame, net, result)) {
|
||
// 跟踪已检测到的目标
|
||
vector<Rect> detection_rects; // 存储检测结果的矩形框
|
||
for (int i = 0; i < result.size(); i++) {
|
||
int x = result[i].box.x;
|
||
int y = result[i].box.y;
|
||
int w = result[i].box.width;
|
||
int h = result[i].box.height;
|
||
Rect rect(x, y, w, h);
|
||
detection_rects.push_back(rect);
|
||
}
|
||
// 更新目标跟踪器
|
||
vector<vector<float>> trackers = mot_tracker.update(detection_rects);//x,y,w,h,id
|
||
// 绘制跟踪结果
|
||
for (int i = 0; i < trackers.size(); i++) {
|
||
Rect rect(trackers[i][0], trackers[i][1], trackers[i][2] - trackers[i][0], trackers[i][3] - trackers[i][1]);
|
||
//cout << "0:" << trackers[i][0] << endl; //跟踪目标的 x 坐标(左上角的 x 坐标)。
|
||
//cout << "1:" << trackers[i][1] << endl; //跟踪目标的 y 坐标(左上角的 y 坐标)。
|
||
//cout << "2:" << trackers[i][2] << endl; //跟踪目标的宽度。
|
||
//cout << "3:" << trackers[i][3] << endl; //跟踪目标的高度。
|
||
//cout << "4:" << trackers[i][4] << endl; //跟踪目标的 ID。
|
||
int id = static_cast<int>(trackers[i][4]);
|
||
//cout << "id:" << id << endl;
|
||
//string label = yolo.className[result[id].id] + ":" + to_string(result[id].confidence);
|
||
rectangle(frame, rect, Scalar(0, 255, 0), 2);
|
||
cout << trackers[i][5] << endl;
|
||
putText(frame, "id:" + to_string(int(trackers[i][4])), Point(rect.x, rect.y), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 2);
|
||
}
|
||
|
||
}
|
||
|
||
imshow("img", frame);
|
||
writer << frame;
|
||
if (waitKey(10) == 27) {
|
||
break; // 退出循环
|
||
}
|
||
|
||
num++;
|
||
}
|
||
|
||
end = clock();
|
||
cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl; //输出时间(单位:s)
|
||
cout << "frame num: " << num << endl;
|
||
cout << "Speed = " << double(end - start) * 1000 / CLOCKS_PER_SEC / num << "ms" << endl; //输出时间(单位:ms)
|
||
|
||
cap.release();
|
||
destroyAllWindows();
|
||
|
||
return 0;
|
||
}
|
||
*/ |