Guide-Vest/include/Yolo.hpp

92 lines
2.7 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#ifndef __YOLO_HPP__
#define __YOLO_HPP__
#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <opencv2/opencv.hpp>
// Model kinds accepted by load(). The original note tags this header as
// "yolov8 instance segmentation".
// Every enumerator has an explicit value; note that 4 is not assigned.
// NOTE(review): the names presumably map to YOLOv5 / YOLOX / YOLOv3 / YOLOv7 /
// YOLOv8 / YOLOv8-seg - confirm against the implementation.
enum class Type : int {
    V5    = 0,
    X     = 1,
    V3    = 2,
    V7    = 3,
    V8    = 5,
    V8Seg = 6,
};
// Stores the result of instance segmentation for one object.
struct InstanceSegmentMap {
    int width = 0;                  // per the original note: width % 8 == 0
    int height = 0;
    unsigned char *data = nullptr;  // per the original note: width * height memory

    // Constructor and destructor are only declared here; their definitions live
    // outside this header.
    // NOTE(review): presumably they allocate and release `data` - confirm
    // against the implementation file.
    InstanceSegmentMap(int width, int height);
    virtual ~InstanceSegmentMap();

    // The struct holds a raw buffer pointer together with a user-declared
    // destructor, so a member-wise copy would leave two objects referring to the
    // same buffer. Copying is therefore disabled (which also suppresses the
    // implicit move operations); share instances through std::shared_ptr instead.
    InstanceSegmentMap(const InstanceSegmentMap &) = delete;
    InstanceSegmentMap &operator=(const InstanceSegmentMap &) = delete;
};
// Bounding box of one detected object, with its score, class and optional mask.
struct Box {
    // Default member initializers make a plain `Box b;` well-defined instead of
    // leaving these fields indeterminate. Zero matches what value-initialization
    // (`Box{}`, `std::vector<Box>(n)`) already produced.
    // NOTE(review): the edges are presumably pixel coordinates in the input
    // image - not established by this header.
    float left = 0.0f;
    float top = 0.0f;
    float right = 0.0f;
    float bottom = 0.0f;
    float confidence = 0.0f;
    int classLabel = 0;
    std::shared_ptr<InstanceSegmentMap> seg;  // valid only in segment task

    Box() = default;

    // Builds a box from its four edges, confidence and class label; `seg` stays
    // empty.
    Box(float left, float top, float right, float bottom, float confidence, int label)
        : left(left),
          top(top),
          right(right),
          bottom(bottom),
          confidence(confidence),
          classLabel(label) {}
};
// Describes an input image: a pointer to the image data plus its width and
// height. The struct declares no destructor, so it never releases the buffer
// it points at.
struct Image {
    const void *bgrptr = nullptr;  // image data; the name suggests BGR layout (not enforced here)
    int width = 0;
    int height = 0;

    Image() = default;

    // Stores the given pointer and dimensions as-is.
    Image(const void *bgrptr, int width, int height)
        : bgrptr{bgrptr}, width{width}, height{height} {}
};
// Container type for the Box structs that make up a detection result.
using BoxArray = std::vector<Box>;
// [Preprocess]: 0.50736 ms
// [Forward]: 3.96410 ms
// [BoxDecode]: 0.12016 ms
// [SegmentDecode]: 0.15610 ms
// Abstract interface that runs model inference and returns detection results.
class Infer {
public:
    // Implementations are handed out as std::shared_ptr<Infer> (see load()), so
    // the base needs a virtual destructor for destruction through an Infer
    // pointer to be well-defined.
    virtual ~Infer() = default;

    // Runs inference on a single image and returns its boxes.
    // `stream` is an opaque handle passed through to the implementation and
    // defaults to nullptr (presumably a CUDA stream - TODO confirm).
    virtual BoxArray forward(const Image &image, void *stream = nullptr) = 0;

    // Batched variant taking several images at once.
    // NOTE(review): the result presumably holds one BoxArray per input, in the
    // same order as `images` - confirm against the implementation.
    virtual std::vector<BoxArray> forwards(const std::vector<Image> &images,
                                           void *stream = nullptr) = 0;
};
// Loads a model of the specified type and returns it behind the Infer
// interface. Only declared here; the definition lives outside this header.
//   engine_file          - path/name of the engine file to load
//   type                 - which model kind the engine contains (see Type)
//   confidence_threshold - defaults to 0.25
//   nms_threshold        - defaults to 0.5
// NOTE(review): how a load failure is reported (null pointer vs. exception) is
// not visible from this header - confirm before relying on either.
std::shared_ptr<Infer> load(const std::string &engine_file, Type type,
float confidence_threshold = 0.25f, float nms_threshold = 0.5f);
// Returns the string representation of the given model type.
// Only declared here; the definition lives outside this header.
const char *type_name(Type type);
// Judging by the name, converts an HSV colour (h, s, v) into three 8-bit
// channels. Only declared here.
// NOTE(review): the tuple order is presumably (b, g, r) and the expected input
// ranges are not visible in this header - confirm against the definition.
std::tuple<uint8_t, uint8_t, uint8_t> hsv2bgr(float h, float s, float v);
// Returns a colour, as three 8-bit channels, for the given id. Only declared
// here.
// NOTE(review): presumably the same id always yields the same colour and the
// channel order matches hsv2bgr - confirm against the definition.
std::tuple<uint8_t, uint8_t, uint8_t> random_color(int id);
// Yolo loads a YOLOv8 model and performs object detection and instance
// segmentation.
class Yolo {
public:
    // Loads the engine at `enginePath` through load(), always as Type::V8Seg.
    // The path is taken by const reference so the string is not copied, and
    // `detector` is set in the member initializer list rather than being
    // default-constructed and then assigned.
    // NOTE(review): the constructor is intentionally left non-explicit to keep
    // existing call sites compiling; whether load() can return a null pointer
    // is not visible from this header.
    Yolo(const std::string &enginePath) : detector(load(enginePath, Type::V8Seg)) {}

    // Label table; only declared here, defined elsewhere.
    // NOTE(review): presumably COCO class names indexed by Box::classLabel -
    // confirm against the definition.
    static const char *cocoLabels[];

    // Runs detection on `colorMat` and returns the resulting boxes. Only
    // declared here; the definition lives outside this header.
    BoxArray detect(cv::Mat colorMat);

    // Inference backend created by the constructor.
    std::shared_ptr<Infer> detector;

    // Original author's note: the yolo-video_demo function directly returns
    // `objs` of type BoxArray (vector<Box>), so box information can be read
    // directly as obj.left and so on.
    // Listed as available: left, top, right, bottom, confidence, box_width,
    // box_height.
    // NOTE(review): box_width / box_height are not members of Box in this
    // header; presumably they are derived as right - left and bottom - top.
};
#endif // __YOLO_HPP__