更新Demo
|
@ -0,0 +1,89 @@
|
|||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms, models
|
||||
import torch.nn as nn
|
||||
|
||||
class AgeGenderPredictor:
|
||||
def __init__(self, model_path):
|
||||
self.model = self.load_model(model_path)
|
||||
self.gender_labels=['Female','Male']
|
||||
|
||||
|
||||
def load_model(self, model_path):
|
||||
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, 3) # 输出为性别和年龄
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def preprocess_image(self, image):
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
input_tensor = preprocess(image)
|
||||
input_batch = input_tensor.unsqueeze(0)
|
||||
return input_batch
|
||||
|
||||
def predict(self, face):
|
||||
input_batch = self.preprocess_image(face)
|
||||
if torch.cuda.is_available():
|
||||
input_batch = input_batch.to('cuda')
|
||||
self.model.to('cuda')
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.model(input_batch)
|
||||
gender_preds = output[:, :2]
|
||||
age_preds = output[:, -1]
|
||||
gender = gender_preds.argmax(dim=1).item()
|
||||
age = age_preds.item()
|
||||
return self.gender_labels[gender], age, self.age_group(age)
|
||||
|
||||
def age_group(self, age):
|
||||
if age <= 18:
|
||||
return 'Teenager'
|
||||
elif age <= 59:
|
||||
return 'Adult'
|
||||
else:
|
||||
return 'Senior'
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 创建 AgeGenderPredictor 类的实例
|
||||
predictor = AgeGenderPredictor('../../weights/megaage_model_epoch99.pth')
|
||||
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
|
||||
# 打开摄像头
|
||||
cap = cv2.VideoCapture(0)
|
||||
|
||||
while True:
|
||||
# 读取一帧
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 进行人脸检测
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
|
||||
|
||||
# 对于检测到的每一个人脸
|
||||
for (x, y, w, h) in faces:
|
||||
# 提取人脸 ROI
|
||||
face = frame[y:y + h, x:x + w]
|
||||
gender, age, age_group = predictor.predict(face)
|
||||
|
||||
cv2.putText(frame, f'Gender: {gender}, Age: {int(age)}, Age Group: {age_group}', (x, y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
|
||||
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
||||
|
||||
# 显示帧
|
||||
cv2.imshow('Webcam', frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
# 释放摄像头并关闭所有窗口
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
|
@ -0,0 +1,32 @@
|
|||
# 基于视觉的年龄性别预测系统
|
||||
|
||||
该项目是一个基于图像的年龄和性别预测系统。它使用ResNet50模型在MegaAge-Asian数据集上进行训练,然后可以从摄像头输入的视频中检测人脸,并为每个检测到的人脸预测年龄、性别和年龄组。
|
||||
|
||||
## 文件结构
|
||||
|
||||
- `AgeGenderPredictor.py`: 包含年龄性别预测模型的加载、预处理和推理逻辑。
|
||||
- `megaage_model_epoch99.pth`: 在MegaAge-Asian数据集上训练的模型权重文件。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`torch`、`torchvision`和`Pillow`。
|
||||
2. 运行`AgeGenderPredictor.py`脚本。
|
||||
3. 脚本将打开默认摄像头,开始人脸检测和年龄性别预测。
|
||||
4. 检测到的人脸周围会用矩形框标注,并显示预测的性别、年龄和年龄组信息。
|
||||
5. 按`q`键退出程序。
|
||||
|
||||
## 模型介绍
|
||||
|
||||
该项目使用ResNet50作为基础模型,对MegaAge-Asian数据集进行训练,以预测人脸图像的年龄和性别。最终模型输出包含3个值,分别对应男性概率、女性概率和估计年龄值。
|
||||
|
||||
### MegaAge-Asian数据集
|
||||
|
||||
MegaAge-Asian是一个大规模的人脸图像数据集,由商汤发布,总数有40000张图像。数据集中的图像包含了不同年龄和性别的亚洲人脸,年龄范围从1岁到70岁。
|
||||
|
||||
## 算法流程
|
||||
|
||||
1. **人脸检测**: 使用OpenCV内置的Haar级联人脸检测器在视频帧中检测人脸。
|
||||
2. **预处理**: 对检测到的人脸图像进行缩放、裁剪和标准化等预处理,以满足模型的输入要求。
|
||||
3. **推理**: 将预处理后的图像输入到预训练的ResNet50模型中,获得性别概率和年龄值的预测结果。
|
||||
4. **后处理**: 根据性别概率确定性别标签,将年龄值映射到具体的年龄组。
|
||||
5. **可视化**: 在视频帧上绘制人脸矩形框,并显示预测的性别、年龄和年龄组信息。
|
|
@ -0,0 +1,151 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from scipy.signal import butter, filtfilt
|
||||
import pywt
|
||||
from .models.lstm import LSTMModel
|
||||
|
||||
|
||||
class BPModel:
|
||||
def __init__(self, model_path, fps=30):
|
||||
self.fps = fps
|
||||
|
||||
self.model = LSTMModel()
|
||||
|
||||
self.load_model(model_path)
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.warmup()
|
||||
|
||||
def predict(self, frames):
|
||||
|
||||
yg, g, t = self.process_frame_sequence(frames, self.fps)
|
||||
|
||||
yg = yg.reshape(1, -1, 1)
|
||||
inputs = torch.tensor(yg.copy(), dtype=torch.float32)
|
||||
inputs = inputs.to(self.device)
|
||||
with torch.no_grad():
|
||||
sbp_outputs, dbp_outputs = self.model(inputs)
|
||||
sbp_outputs = sbp_outputs.cpu().detach().numpy().item()
|
||||
dbp_outputs = dbp_outputs.cpu().detach().numpy().item()
|
||||
|
||||
return sbp_outputs, dbp_outputs
|
||||
|
||||
def load_model(self, model_path):
|
||||
|
||||
model_state_dict = torch.load(model_path)
|
||||
|
||||
# 判断model_state_dict的类型是否是OrderedDict
|
||||
if not isinstance(model_state_dict, OrderedDict):
|
||||
# model_state_dict=model_state_dict.state_dict()
|
||||
# 若不是OrderedDict类型,则为LSMTModel类型,直接加载
|
||||
self.model = model_state_dict
|
||||
return
|
||||
|
||||
# 判断是否是多GPU训练的模型
|
||||
if 'module' in model_state_dict.keys():
|
||||
self.model.load_state_dict(model_state_dict['module'])
|
||||
else:
|
||||
# 遍历模型参数,判断参数前是否有module.
|
||||
new_state_dict = {}
|
||||
for k, v in model_state_dict.items():
|
||||
if 'module.' in k:
|
||||
name = k[7:]
|
||||
else:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
self.model.load_state_dict(new_state_dict)
|
||||
|
||||
# 模型预热
|
||||
def warmup(self):
|
||||
inputs = torch.randn(10, 250, 1)
|
||||
inputs = inputs.to(self.device)
|
||||
with torch.no_grad():
|
||||
self.model(inputs)
|
||||
|
||||
def wavelet_detrend(self, signal, wavelet='sym6', level=6):
|
||||
"""
|
||||
小波分解和基线漂移去除
|
||||
|
||||
参数:
|
||||
signal (numpy.ndarray): 输入信号
|
||||
wavelet (str): 小波基函数名称,默认为'sym6'
|
||||
level (int): 小波分解层数,默认为6
|
||||
|
||||
返回:
|
||||
detrended_signal (numpy.ndarray): 去除基线漂移后的信号
|
||||
"""
|
||||
# 执行小波分解
|
||||
coeffs = pywt.wavedec(signal, wavelet, level=level)
|
||||
|
||||
# 获取第六层近似分量(基线漂移)
|
||||
cA6 = coeffs[0]
|
||||
|
||||
# 重构信号,去除基线漂移
|
||||
coeffs[0] = np.zeros_like(cA6) # 将基线漂移分量置为零
|
||||
detrended_signal = pywt.waverec(coeffs, wavelet)
|
||||
|
||||
return detrended_signal
|
||||
|
||||
def butter_bandpass(self, lowcut, highcut, fs, order=5):
|
||||
nyq = 0.5 * fs
|
||||
low = lowcut / nyq
|
||||
high = highcut / nyq
|
||||
b, a = butter(order, [low, high], btype='band')
|
||||
return b, a
|
||||
|
||||
def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
|
||||
b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
|
||||
y = filtfilt(b, a, data)
|
||||
return y
|
||||
|
||||
def process_frame_sequence(self, frames, fps):
|
||||
"""
|
||||
处理帧序列
|
||||
|
||||
参数:
|
||||
frames (list): 包含所有帧的列表,每一帧为numpy.ndarray
|
||||
|
||||
返回:
|
||||
t (list): 时间序列(秒),从0开始
|
||||
yg (numpy.ndarray): 处理后的绿色通道数据
|
||||
green (numpy.ndarray): 原始绿色通道数据
|
||||
"""
|
||||
all_frames = frames
|
||||
|
||||
green = []
|
||||
for frame in all_frames:
|
||||
r, g, b = (frame.mean(axis=0)).mean(axis=0)
|
||||
green.append(g)
|
||||
|
||||
t = [i / fps for i in range(len(all_frames))]
|
||||
|
||||
g_detrended = self.wavelet_detrend(green)
|
||||
lowcut = 0.6
|
||||
highcut = 8
|
||||
datag = g_detrended
|
||||
yg = self.butter_bandpass_filter(datag, lowcut, highcut, fps, order=4)
|
||||
|
||||
# self.plot(green, t, 'Original Green Channel',color='green')
|
||||
# self.plot(g_detrended, t, 'Detrended Green Channel', color='red')
|
||||
# self.plot(yg, t, 'Filtered Green Channel',color='blue')
|
||||
|
||||
return yg, green, t
|
||||
|
||||
def plot(self, yg, t, title, color='green', figsize=(30, 10)):
|
||||
plt.figure(figsize=figsize)
|
||||
plt.plot(t, yg, label=title, color=color)
|
||||
plt.xlabel('Time (s)')
|
||||
plt.ylabel('Amplitude')
|
||||
plt.legend()
|
||||
plt.show()
|
|
@ -0,0 +1,63 @@
|
|||
# 基于rPPG的血压估计系统
|
||||
|
||||
该项目是一个基于远程光电容积脉搏波描记法(rPPG)的血压估计系统。它使用LSTM神经网络在多个rPPG数据集上进行训练,然后可以从视频中提取光学脉冲信号,并预测个体的收缩压(SBP)和舒张压(DBP)值。
|
||||
|
||||
## 核心文件
|
||||
|
||||
- `BPApi.py`: 包含BP估计模型的核心逻辑,如信号预处理、模型推理等。
|
||||
- `lstm.py`: 定义了用于BP估计的LSTM神经网络架构。
|
||||
- `video.py`: 视频处理、人脸检测和BP估计的主要脚本。
|
||||
- `best_model.pth`: 在多个数据集上训练的最佳模型权重文件。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`torch`、`numpy`、`scipy`和`pywavelets`。
|
||||
2. 运行`video.py`脚本。
|
||||
3. 脚本将打开默认摄像头,开始人脸检测和BP估计。
|
||||
4. 检测到的人脸区域将被提取用于BP估计,预测结果将显示在视频流窗口中。
|
||||
5. 按`q`键退出程序。
|
||||
|
||||
## 模型介绍
|
||||
|
||||
该项目使用LSTM神经网络作为基础模型,使用大规模PPG信号数据集进行预训练,并进一步使用rPPG信号数据集进行微调,以预测个体的SBP和DBP值。模型输出包含两个值,分别对应SBP和DBP的预测值。
|
||||
|
||||
### 数据集介绍
|
||||
|
||||
该项目使用了以下三个公开的rPPG数据集进行训练:
|
||||
|
||||
1. **MIMIC-III数据集**: 包含9054000条PPG信号序列和对应的SBP/DBP标签。
|
||||
2. **UKL-rPPG数据集**: 包含7851条rPPG信号序列和对应的SBP/DBP标签。
|
||||
3. **iPPG-BP数据集**: 包含2120条rPPG信号序列和对应的SBP/DBP标签。
|
||||
|
||||
## 算法流程
|
||||
|
||||
1. **视频采集**:
|
||||
- 使用OpenCV库初始化视频捕捉对象,并获取视频的帧率。
|
||||
|
||||
2. **人脸检测**:
|
||||
- 在每一帧上使用Haar级联人脸检测器进行人脸检测。
|
||||
- 如果检测到人脸,获取人脸区域的边界框坐标。
|
||||
|
||||
3. **帧序列提取**:
|
||||
- 维护一个固定长度(如250帧)的循环队列,用于存储最近的人脸帧序列。
|
||||
- 对于新检测到的人脸,将其添加到队列中。
|
||||
|
||||
4. **信号预处理**:
|
||||
- 当队列满时,执行以下预处理步骤:
|
||||
- 从人脸帧序列中提取绿色通道信号。
|
||||
- 使用小波变换进行去趋势,消除基线漂移。
|
||||
- 使用带通滤波器去除高频和低频噪声,保留有效的脉搏频率范围。
|
||||
|
||||
5. **推理**:
|
||||
- 将预处理后的绿色通道信号输入到LSTM神经网络模型中。
|
||||
- 模型输出SBP和DBP的预测值。
|
||||
|
||||
6. **可视化**:
|
||||
- 在视频帧上绘制人脸边界框。
|
||||
- 在视频帧上显示预测的SBP和DBP值。
|
||||
|
||||
7. **持续循环**:
|
||||
- 对新的视频帧重复执行步骤2-6,持续进行人脸检测、BP估计和可视化。
|
||||
|
||||
8. **退出**:
|
||||
- 当用户按下特定按键(如'q')时,退出程序,关闭视频捕捉对象和所有窗口。
|
|
@ -0,0 +1,285 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from sklearn.model_selection import train_test_split
|
||||
import h5py
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
X, y_SBP, y_DBP = zip(*batch)
|
||||
|
||||
X = torch.tensor(np.array(X), dtype=torch.float32)
|
||||
y_SBP = torch.tensor(y_SBP, dtype=torch.float32)
|
||||
y_DBP = torch.tensor(y_DBP, dtype=torch.float32)
|
||||
|
||||
return X, y_SBP, y_DBP
|
||||
|
||||
|
||||
class BPDataset(Dataset):
|
||||
def __init__(self, X_data, y_SBP, y_DBP):
|
||||
self.X_data = X_data
|
||||
self.y_SBP = y_SBP
|
||||
self.y_DBP = y_DBP
|
||||
|
||||
def __len__(self):
|
||||
return len(self.y_SBP)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# X_sample = self.X_data[idx * 250:(idx + 1) * 250]
|
||||
X_sample = self.X_data[idx]
|
||||
y_SBP_sample = self.y_SBP[idx]
|
||||
y_DBP_sample = self.y_DBP[idx]
|
||||
|
||||
return X_sample, y_SBP_sample, y_DBP_sample
|
||||
|
||||
|
||||
class BPDataLoader:
|
||||
def __init__(self, data_dir, val_split=0.2, batch_size=32, shuffle=True, data_type='npy'):
|
||||
self.data_dir = data_dir
|
||||
self.val_split = val_split
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.train_dataloader = None
|
||||
self.val_dataloader = None
|
||||
self.data_type = data_type
|
||||
|
||||
|
||||
def load_data(self):
|
||||
X_BP_path = os.path.join(self.data_dir, 'X_BP.npy')
|
||||
y_DBP_path = os.path.join(self.data_dir, 'Y_DBP.npy')
|
||||
y_SBP_path = os.path.join(self.data_dir, 'Y_SBP.npy')
|
||||
|
||||
X_BP = np.load(X_BP_path)
|
||||
# 将数据reshape成(batch_size, 250,1)的形状
|
||||
X_BP = X_BP.reshape(-1, 250, 1)
|
||||
|
||||
y_DBP = np.load(y_DBP_path)
|
||||
y_SBP = np.load(y_SBP_path)
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
def load_data_UKL_h5(self):
|
||||
|
||||
X_BP_path = os.path.join(self.data_dir, 'rPPG-BP-UKL_rppg_7s.h5')
|
||||
with h5py.File(X_BP_path, 'r') as f:
|
||||
rppg = f.get('rppg')
|
||||
BP = f.get('label')
|
||||
rppg = np.array(rppg)
|
||||
BP = np.array(BP)
|
||||
|
||||
# 将数据从(875, 7851)reshape成(7851, 875, 1)的形状
|
||||
rppg = rppg.transpose(1, 0)
|
||||
rppg = rppg.reshape(-1, 875, 1)
|
||||
|
||||
X_BP = rppg
|
||||
y_DBP = BP[1]
|
||||
y_SBP = BP[0]
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
def load_data_MIMIC_h5(self):
|
||||
|
||||
X_BP_path = os.path.join(self.data_dir, 'MIMIC-III_ppg_dataset.h5')
|
||||
|
||||
#
|
||||
# 获取data_dir下文件列表
|
||||
files = os.listdir(self.data_dir)
|
||||
|
||||
# 检查是否存在已经处理好的数据
|
||||
if 'X_MIMIC_BP.npy' in files and 'Y_MIMIC_DBP.npy' in files and 'Y_MIMIC_SBP.npy' in files:
|
||||
print('loading preprocessed data.....')
|
||||
|
||||
X_BP = np.load(os.path.join(self.data_dir, 'X_MIMIC_BP.npy'))
|
||||
y_DBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_DBP.npy'))
|
||||
y_SBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_SBP.npy'))
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
with h5py.File(X_BP_path, 'r') as f:
|
||||
ppg = f.get('ppg')
|
||||
BP = f.get('label')
|
||||
ppg = np.array(ppg)
|
||||
BP = np.array(BP)
|
||||
|
||||
# 统计BP中SBP的最大值和最小值
|
||||
max_sbp = np.max(BP[:, 0])
|
||||
min_sbp = np.min(BP[:, 0])
|
||||
|
||||
max_sbp = 10 - max_sbp % 10 + max_sbp
|
||||
min_sbp = min_sbp - min_sbp % 10
|
||||
|
||||
# 划分区间
|
||||
bins = np.arange(min_sbp, max_sbp, 10)
|
||||
|
||||
print(bins)
|
||||
|
||||
sampled_ppg_data = []
|
||||
sampled_bp_data = []
|
||||
|
||||
for i in range(len(bins) - 1):
|
||||
# 获取当前区间的数据
|
||||
bin_data_sbp_dbp = BP[(BP[:, 0] >= bins[i]) & (BP[:, 0] < bins[i + 1])]
|
||||
bin_data_ppg = ppg[(BP[:, 0] >= bins[i]) & (BP[:, 0] < bins[i + 1])]
|
||||
|
||||
# 如果当前区间有数据
|
||||
if len(bin_data_sbp_dbp) > 0:
|
||||
# 从当前区间中随机抽取20%的数据
|
||||
num_samples = int(len(bin_data_sbp_dbp) * 0.1)
|
||||
indices = np.random.choice(len(bin_data_sbp_dbp), num_samples, replace=False)
|
||||
sampled_bin_data_sbp_dbp = bin_data_sbp_dbp[indices]
|
||||
sampled_bin_data_ppg = bin_data_ppg[indices]
|
||||
|
||||
# 将抽取的数据添加到最终的列表中
|
||||
sampled_bp_data.append(sampled_bin_data_sbp_dbp)
|
||||
sampled_ppg_data.append(sampled_bin_data_ppg)
|
||||
|
||||
# 将列表中的数据合并成NumPy数组
|
||||
ppg = np.concatenate(sampled_ppg_data, axis=0)
|
||||
BP = np.concatenate(sampled_bp_data, axis=0)
|
||||
|
||||
print(ppg.shape, BP.shape)
|
||||
|
||||
# 将数据从(9054000, 875)reshape成(9054000, 875, 1)的形状
|
||||
ppg = ppg.reshape(-1, 875, 1)
|
||||
|
||||
X_BP = ppg
|
||||
|
||||
# 取出第一列赋值给y_DBP,第0列赋值给y_SBP
|
||||
y_DBP = BP[:, 1]
|
||||
y_SBP = BP[:, 0]
|
||||
|
||||
# 将数据保存到文件中
|
||||
np.save('data/X_MIMIC_BP.npy', X_BP)
|
||||
np.save('data/Y_MIMIC_DBP.npy', y_DBP)
|
||||
np.save('data/Y_MIMIC_SBP.npy', y_SBP)
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
def load_data_MIMIC_h5_full(self):
|
||||
|
||||
X_BP_path = os.path.join(self.data_dir, 'MIMIC-III_ppg_dataset.h5')
|
||||
|
||||
# 获取data_dir下文件列表
|
||||
files = os.listdir(self.data_dir)
|
||||
|
||||
# 检查是否存在已经处理好的数据
|
||||
if 'X_MIMIC_BP_full.npy' in files and 'Y_MIMIC_DBP_full.npy' in files and 'Y_MIMIC_SBP_full.npy' in files:
|
||||
print('loading preprocessed data.....')
|
||||
|
||||
X_BP = np.load(os.path.join(self.data_dir, 'X_MIMIC_BP_full.npy'))
|
||||
y_DBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_DBP_full.npy'))
|
||||
y_SBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_SBP_full.npy'))
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
with h5py.File(X_BP_path, 'r') as f:
|
||||
ppg = f.get('ppg')
|
||||
BP = f.get('label')
|
||||
ppg = np.array(ppg)
|
||||
BP = np.array(BP)
|
||||
|
||||
|
||||
# 将数据从(9054000, 875)reshape成(9054000, 875, 1)的形状
|
||||
ppg = ppg.reshape(-1, 875, 1)
|
||||
|
||||
X_BP = ppg
|
||||
|
||||
# 取出第一列赋值给y_DBP,第0列赋值给y_SBP
|
||||
y_DBP = BP[:, 1]
|
||||
y_SBP = BP[:, 0]
|
||||
|
||||
print("data shape:", X_BP.shape, y_DBP.shape, y_SBP.shape)
|
||||
|
||||
print("saving data.....")
|
||||
|
||||
# 将数据保存到文件中
|
||||
np.save('data/X_MIMIC_BP_full.npy', X_BP)
|
||||
np.save('data/Y_MIMIC_DBP_full.npy', y_DBP)
|
||||
np.save('data/Y_MIMIC_SBP_full.npy', y_SBP)
|
||||
|
||||
print("data saved.....")
|
||||
|
||||
return X_BP, y_DBP, y_SBP
|
||||
|
||||
def create_dataset(self, X_data, y_SBP, y_DBP):
|
||||
return BPDataset(X_data, y_SBP, y_DBP)
|
||||
|
||||
def split_data(self, X_data, y_SBP, y_DBP):
|
||||
X_train, X_val, y_train_SBP, y_val_SBP, y_train_DBP, y_val_DBP = train_test_split(
|
||||
X_data, y_SBP, y_DBP, test_size=self.val_split, random_state=42
|
||||
)
|
||||
|
||||
# print(X_train.shape, X_val.shape, y_train_SBP.shape, y_val_SBP.shape, y_train_DBP.shape, y_val_DBP.shape)
|
||||
|
||||
train_dataset = self.create_dataset(X_train, y_train_SBP, y_train_DBP)
|
||||
val_dataset = self.create_dataset(X_val, y_val_SBP, y_val_DBP)
|
||||
|
||||
return train_dataset, val_dataset
|
||||
|
||||
def create_dataloaders(self):
|
||||
if self.data_type == 'UKL':
|
||||
X_data, y_DBP, y_SBP = self.load_data_UKL_h5()
|
||||
elif self.data_type == 'MIMIC':
|
||||
X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5()
|
||||
elif self.data_type == 'MIMIC_full':
|
||||
X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5_full()
|
||||
else:
|
||||
X_data, y_DBP, y_SBP = self.load_data()
|
||||
train_dataset, val_dataset = self.split_data(X_data, y_SBP, y_DBP)
|
||||
|
||||
self.train_dataloader = DataLoader(
|
||||
train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, collate_fn=custom_collate_fn
|
||||
)
|
||||
self.val_dataloader = DataLoader(
|
||||
val_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=custom_collate_fn
|
||||
)
|
||||
|
||||
def get_dataloaders(self):
|
||||
if self.train_dataloader is None or self.val_dataloader is None:
|
||||
self.create_dataloaders()
|
||||
|
||||
return self.train_dataloader, self.val_dataloader
|
||||
|
||||
def get_distributed_dataloaders(self, world_size, rank):
|
||||
|
||||
if self.data_type == 'UKL':
|
||||
X_data, y_DBP, y_SBP = self.load_data_UKL_h5()
|
||||
elif self.data_type == 'MIMIC':
|
||||
X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5()
|
||||
elif self.data_type == 'MIMIC_full':
|
||||
X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5_full()
|
||||
else:
|
||||
X_data, y_DBP, y_SBP = self.load_data()
|
||||
train_dataset, val_dataset = self.split_data(X_data, y_SBP, y_DBP)
|
||||
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
train_dataset, num_replicas=world_size, rank=rank, shuffle=True
|
||||
)
|
||||
val_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
val_dataset, num_replicas=world_size, rank=rank, shuffle=False
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=custom_collate_fn,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
sampler=val_sampler,
|
||||
collate_fn=custom_collate_fn,
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader, train_sampler, val_sampler
|
||||
|
||||
# 使用示例
|
||||
#
|
||||
# data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=32,data_type='MIMIC')
|
||||
# train_dataloader, val_dataloader = data_loader.get_dataloaders()
|
||||
#
|
||||
# for i, (X, y_SBP, y_DBP) in enumerate(train_dataloader):
|
||||
# print(f"Batch {i+1}: X.shape={X.shape }, y_SBP.shape={y_SBP.shape}, y_DBP.shape={y_DBP.shape}")
|
||||
# if i == 2:
|
||||
# break
|
|
@ -0,0 +1,195 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
# 定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size = 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
def train(gpu, args):
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
|
||||
rank = args.nr * args.gpus + gpu
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
# init_method="env://",
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# 设置当前 GPU 设备
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# 创建模型并移动到对应 GPU
|
||||
model = LSTMModel().to(gpu)
|
||||
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
|
||||
|
||||
# 定义损失函数
|
||||
criterion = nn.MSELoss().to(gpu)
|
||||
|
||||
# 定义优化器
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 准备数据加载器
|
||||
data_type = 'MIMIC_full'
|
||||
|
||||
# #检查模型存放路径是否存在
|
||||
# if not os.path.exists(f'weights'):
|
||||
# os.makedirs(f'weights')
|
||||
# if not os.path.exists(f'weights/{data_type}'):
|
||||
# os.makedirs(f'weights/{data_type}')
|
||||
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
|
||||
train_loader, val_loader ,train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)
|
||||
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_sampler.set_epoch(epoch)
|
||||
train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)
|
||||
|
||||
val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)
|
||||
|
||||
if gpu == 0:
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.module, f'weights/{data_type}/last.pth')
|
||||
|
||||
scheduler.step()
|
||||
|
||||
writer.close()
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= dist.get_world_size()
|
||||
return rt
|
||||
|
||||
def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0),desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
reduced_loss = reduce_tensor(loss)
|
||||
|
||||
reduced_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += reduced_loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
return running_loss / len(dataloader)
|
||||
|
||||
def run_evaluate(model, dataloader, criterion, gpu):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
reduced_loss_sbp = reduce_tensor(loss_sbp)
|
||||
reduced_loss_dbp = reduce_tensor(loss_dbp)
|
||||
|
||||
running_loss_sbp += reduced_loss_sbp.item()
|
||||
running_loss_dbp += reduced_loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader)
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader)
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nr", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
|
||||
if ngpus_per_node>4:
|
||||
ngpus_per_node = 4
|
||||
|
||||
args.world_size = ngpus_per_node
|
||||
args.gpus = max(ngpus_per_node, 1)
|
||||
mp.spawn(train, nprocs=args.gpus, args=(args,))
|
||||
|
||||
#检查模型存放路径是否存在
|
||||
def check_path(data_type):
|
||||
if not os.path.exists(f'weights'):
|
||||
os.makedirs(f'weights')
|
||||
if not os.path.exists(f'weights/{data_type}'):
|
||||
os.makedirs(f'weights/{data_type}')
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
check_path('MIMIC_full')
|
||||
main()
|
|
@ -0,0 +1,200 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
# 定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size = 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
def train(gpu, args):
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
|
||||
rank = args.nr * args.gpus + gpu
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
# init_method="env://",
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# 设置当前 GPU 设备
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# 创建模型并移动到对应 GPU
|
||||
model = LSTMModel().to(gpu)
|
||||
|
||||
w = torch.load(r'weights/MIMIC_full/best_90_lstm_model_sbp267.4183_dbp89.7367.pth',
|
||||
map_location=torch.device(f'cuda:{gpu}'))
|
||||
|
||||
# 加载权重
|
||||
model.load_state_dict(w.state_dict())
|
||||
|
||||
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
|
||||
|
||||
|
||||
|
||||
# 定义损失函数
|
||||
criterion = nn.MSELoss().to(gpu)
|
||||
|
||||
# 定义优化器
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 准备数据加载器
|
||||
data_type = 'UKL'
|
||||
|
||||
# #检查模型存放路径是否存在
|
||||
# if not os.path.exists(f'weights'):
|
||||
# os.makedirs(f'weights')
|
||||
# if not os.path.exists(f'weights/{data_type}'):
|
||||
# os.makedirs(f'weights/{data_type}')
|
||||
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
|
||||
train_loader, val_loader ,train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)
|
||||
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_sampler.set_epoch(epoch)
|
||||
train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)
|
||||
|
||||
val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)
|
||||
|
||||
if gpu == 0:
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.module, f'weights/{data_type}/last.pth')
|
||||
|
||||
scheduler.step()
|
||||
|
||||
writer.close()
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= dist.get_world_size()
|
||||
return rt
|
||||
|
||||
def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0),desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
reduced_loss = reduce_tensor(loss)
|
||||
|
||||
reduced_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += reduced_loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
return running_loss / len(dataloader)
|
||||
|
||||
def run_evaluate(model, dataloader, criterion, gpu):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
reduced_loss_sbp = reduce_tensor(loss_sbp)
|
||||
reduced_loss_dbp = reduce_tensor(loss_dbp)
|
||||
|
||||
running_loss_sbp += reduced_loss_sbp.item()
|
||||
running_loss_dbp += reduced_loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader)
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader)
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nr", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
|
||||
if ngpus_per_node>4:
|
||||
ngpus_per_node = 4
|
||||
|
||||
args.world_size = ngpus_per_node
|
||||
args.gpus = max(ngpus_per_node, 1)
|
||||
mp.spawn(train, nprocs=args.gpus, args=(args,))
|
||||
|
||||
def check_path(data_type):
|
||||
if not os.path.exists(f'weights'):
|
||||
os.makedirs(f'weights')
|
||||
if not os.path.exists(f'weights/{data_type}'):
|
||||
os.makedirs(f'weights/{data_type}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_path('UKL')
|
||||
main()
|
|
@ -0,0 +1,198 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
# 定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size = 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
def train(gpu, args):
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
|
||||
rank = args.nr * args.gpus + gpu
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
# init_method="env://",
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# 设置当前 GPU 设备
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# 创建模型并移动到对应 GPU
|
||||
model = LSTMModel().to(gpu)
|
||||
|
||||
|
||||
w = torch.load(r'weights/UKL/best_99_lstm_model_sbp90.9980_dbp51.0640.pth',
|
||||
map_location=torch.device(f'cuda:{gpu}'))
|
||||
|
||||
# 加载权重
|
||||
model.load_state_dict(w.state_dict())
|
||||
|
||||
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
|
||||
|
||||
# 定义损失函数
|
||||
criterion = nn.MSELoss().to(gpu)
|
||||
|
||||
# 定义优化器
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 准备数据加载器
|
||||
data_type = 'X'
|
||||
|
||||
# #检查模型存放路径是否存在
|
||||
# if not os.path.exists(f'weights'):
|
||||
# os.makedirs(f'weights')
|
||||
# if not os.path.exists(f'weights/{data_type}'):
|
||||
# os.makedirs(f'weights/{data_type}')
|
||||
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
|
||||
train_loader, val_loader ,train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)
|
||||
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_sampler.set_epoch(epoch)
|
||||
train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)
|
||||
|
||||
val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)
|
||||
|
||||
if gpu == 0:
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.module, f'weights/{data_type}/last.pth')
|
||||
|
||||
scheduler.step()
|
||||
|
||||
writer.close()
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= dist.get_world_size()
|
||||
return rt
|
||||
|
||||
def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0),desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
reduced_loss = reduce_tensor(loss)
|
||||
|
||||
reduced_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += reduced_loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
return running_loss / len(dataloader)
|
||||
|
||||
def run_evaluate(model, dataloader, criterion, gpu):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.cuda(gpu, non_blocking=True)
|
||||
sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
|
||||
dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
reduced_loss_sbp = reduce_tensor(loss_sbp)
|
||||
reduced_loss_dbp = reduce_tensor(loss_dbp)
|
||||
|
||||
running_loss_sbp += reduced_loss_sbp.item()
|
||||
running_loss_dbp += reduced_loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader)
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader)
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nr", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
|
||||
if ngpus_per_node>4:
|
||||
ngpus_per_node = 4
|
||||
args.world_size = ngpus_per_node
|
||||
args.gpus = max(ngpus_per_node, 1)
|
||||
mp.spawn(train, nprocs=args.gpus, args=(args,))
|
||||
|
||||
def check_path(data_type):
|
||||
if not os.path.exists(f'weights'):
|
||||
os.makedirs(f'weights')
|
||||
if not os.path.exists(f'weights/{data_type}'):
|
||||
os.makedirs(f'weights/{data_type}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_path('X')
|
||||
main()
|
|
@ -0,0 +1,62 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LSTMModel(nn.Module):
|
||||
def __init__(self, input_size=1, hidden_size=128, output_size=2):
|
||||
super(LSTMModel, self).__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
|
||||
self.conv1d = nn.Conv1d(input_size, 64, kernel_size=5, padding=2)
|
||||
self.relu = nn.ReLU()
|
||||
self.lstm1 = nn.LSTM(64, hidden_size, bidirectional=True, batch_first=True)
|
||||
self.lstm2 = nn.LSTM(hidden_size * 2, hidden_size, bidirectional=True, batch_first=True)
|
||||
self.lstm3 = nn.LSTM(hidden_size * 2, 64, bidirectional=False, batch_first=True)
|
||||
self.fc1 = nn.Linear(64, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, 128)
|
||||
self.fc_sbp = nn.Linear(128, 1)
|
||||
self.fc_dbp = nn.Linear(128, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# 将输入传递给Conv1d层
|
||||
x = self.conv1d(x.permute(0, 2, 1).contiguous())
|
||||
x = self.relu(x)
|
||||
x = x.permute(0, 2, 1).contiguous()
|
||||
|
||||
# 将输入传递给LSTM层
|
||||
x, _ = self.lstm1(x)
|
||||
x, _ = self.lstm2(x)
|
||||
x, _ = self.lstm3(x)
|
||||
|
||||
# 只使用最后一个时间步的输出
|
||||
x = x[:, -1, :]
|
||||
|
||||
# 将LSTM输出传递给全连接层
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.relu(self.fc3(x))
|
||||
|
||||
# 从两个Linear输出最终结果
|
||||
sbp = self.fc_sbp(x)
|
||||
dbp = self.fc_dbp(x)
|
||||
|
||||
return sbp, dbp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建模型实例
|
||||
model = LSTMModel()
|
||||
|
||||
# 定义示例输入
|
||||
batch_size = 64
|
||||
seq_len = 1250
|
||||
input_size = 1
|
||||
input_data = torch.randn(batch_size, seq_len, input_size)
|
||||
|
||||
# 将输入数据传递给模型
|
||||
sbp, dbp = model(input_data)
|
||||
print(sbp.shape, dbp.shape) # 输出: torch.Size([64, 1]) torch.Size([64, 1])
|
|
@ -0,0 +1,62 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LSTMModel(nn.Module):
|
||||
def __init__(self, input_size=1, hidden_size=128, output_size=2):
|
||||
super(LSTMModel, self).__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
|
||||
self.conv1d = nn.Conv1d(input_size, 64, kernel_size=5, padding=2)
|
||||
self.relu = nn.ReLU()
|
||||
self.lstm1 = nn.LSTM(64, hidden_size, bidirectional=True, batch_first=True)
|
||||
self.lstm2 = nn.LSTM(hidden_size * 2, hidden_size, bidirectional=True, batch_first=True)
|
||||
self.lstm3 = nn.LSTM(hidden_size * 2, 64, bidirectional=False, batch_first=True)
|
||||
self.fc1 = nn.Linear(64, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, 128)
|
||||
self.fc_sbp = nn.Linear(128, 1)
|
||||
self.fc_dbp = nn.Linear(128, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# 将输入传递给Conv1d层
|
||||
x = self.conv1d(x.permute(0, 2, 1).contiguous())
|
||||
x = self.relu(x)
|
||||
x = x.permute(0, 2, 1).contiguous()
|
||||
|
||||
# 将输入传递给LSTM层
|
||||
x, _ = self.lstm1(x)
|
||||
x, _ = self.lstm2(x)
|
||||
x, _ = self.lstm3(x)
|
||||
|
||||
# 只使用最后一个时间步的输出
|
||||
x = x[:, -1, :]
|
||||
|
||||
# 将LSTM输出传递给全连接层
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.relu(self.fc3(x))
|
||||
|
||||
# 从两个Linear输出最终结果
|
||||
sbp = self.fc_sbp(x)
|
||||
dbp = self.fc_dbp(x)
|
||||
|
||||
return sbp, dbp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建模型实例
|
||||
model = LSTMModel()
|
||||
|
||||
# 定义示例输入
|
||||
batch_size = 64
|
||||
seq_len = 1250
|
||||
input_size = 1
|
||||
input_data = torch.randn(batch_size, seq_len, input_size)
|
||||
|
||||
# 将输入数据传递给模型
|
||||
sbp, dbp = model(input_data)
|
||||
print(sbp.shape, dbp.shape) # 输出: torch.Size([64, 1]) torch.Size([64, 1])
|
|
@ -0,0 +1,138 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
# 定义模型
|
||||
model = LSTMModel()
|
||||
|
||||
#定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size= 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
# 定义损失函数和优化器
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
# 训练函数
|
||||
def train(model, dataloader, epoch, device,batch_size):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
scheduler.step()
|
||||
writer.add_scalar("Loss/train", running_loss / len(dataloader)/ batch_size, epoch)
|
||||
|
||||
return running_loss / len(dataloader) / batch_size
|
||||
|
||||
# 评估函数
|
||||
def evaluate(model, dataloader, device,batch_size):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
running_loss_sbp += loss_sbp.item()
|
||||
running_loss_dbp += loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
# 训练循环
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
data_type='MIMIC_full'
|
||||
|
||||
#判断权重保存目录是否存在,不存在则创建
|
||||
if not os.path.exists('weights'):
|
||||
os.makedirs('weights')
|
||||
#在其中创建data_type同名子文件夹
|
||||
os.makedirs(os.path.join('weights',data_type))
|
||||
else:
|
||||
#判断子文件夹是否存在
|
||||
if not os.path.exists(os.path.join('weights',data_type)):
|
||||
os.makedirs(os.path.join('weights',data_type))
|
||||
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size,data_type=data_type)
|
||||
|
||||
train_dataloader, val_dataloader = data_loader.get_dataloaders()
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_loss = train(model, train_dataloader, epoch, device,batch_size)
|
||||
val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device,batch_size)
|
||||
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.state_dict(), f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.state_dict(),
|
||||
f'weights/{data_type}/last.pth')
|
||||
|
||||
writer.close()
|
|
@ -0,0 +1,144 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
# 定义模型
|
||||
model = LSTMModel()
|
||||
|
||||
# 加载权重
|
||||
model.load_state_dict(torch.load(r'weights/MIMIC/best_27_lstm_model_sbp1.4700_dbp0.4493.pth'))
|
||||
|
||||
# 定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size= 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
# 定义损失函数和优化器
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
|
||||
# 训练函数
|
||||
def train(model, dataloader, epoch, device, batch_size):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch + 1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
scheduler.step()
|
||||
writer.add_scalar("Loss/train", running_loss / len(dataloader) / batch_size, epoch)
|
||||
|
||||
return running_loss / len(dataloader) / batch_size
|
||||
|
||||
|
||||
# 评估函数
|
||||
def evaluate(model, dataloader, device, batch_size):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
running_loss_sbp += loss_sbp.item()
|
||||
running_loss_dbp += loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
|
||||
# 训练循环
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
data_type = 'UKL'
|
||||
|
||||
# 判断权重保存目录是否存在,不存在则创建
|
||||
if not os.path.exists('weights'):
|
||||
os.makedirs('weights')
|
||||
# 在其中创建data_type同名子文件夹
|
||||
os.makedirs(os.path.join('weights', data_type))
|
||||
else:
|
||||
# 判断子文件夹是否存在
|
||||
if not os.path.exists(os.path.join('weights', data_type)):
|
||||
os.makedirs(os.path.join('weights', data_type))
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type='UKL')
|
||||
|
||||
train_dataloader, val_dataloader = data_loader.get_dataloaders()
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_loss = train(model, train_dataloader, epoch, device, batch_size)
|
||||
val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)
|
||||
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(
|
||||
f"Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.state_dict(),
|
||||
f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.state_dict(),
|
||||
f'weights/{data_type}/last.pth')
|
||||
|
||||
writer.close()
|
|
@ -0,0 +1,144 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataloader import BPDataLoader
|
||||
from apis.lstm import LSTMModel
|
||||
|
||||
# 定义模型
|
||||
model = LSTMModel()
|
||||
|
||||
# 加载权重
|
||||
model.load_state_dict(torch.load(r'weights/UKL/best_28_lstm_model_sbp0.3520_dbp0.2052.pth'))
|
||||
|
||||
# 定义训练参数
|
||||
max_epochs = 100
|
||||
batch_size= 1024
|
||||
warmup_epochs = 10
|
||||
lr = 0.0005
|
||||
|
||||
# 定义损失函数和优化器
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# 定义学习率调度器
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)
|
||||
|
||||
# 定义TensorBoard写入器
|
||||
writer = SummaryWriter()
|
||||
|
||||
|
||||
# 训练函数
|
||||
def train(model, dataloader, epoch, device, batch_size):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch + 1}/{max_epochs}")
|
||||
for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
loss = loss_sbp + loss_dbp
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
pbar.set_postfix(loss=running_loss / (i + 1))
|
||||
|
||||
scheduler.step()
|
||||
writer.add_scalar("Loss/train", running_loss / len(dataloader) / batch_size, epoch)
|
||||
|
||||
return running_loss / len(dataloader) / batch_size
|
||||
|
||||
|
||||
# 评估函数
|
||||
def evaluate(model, dataloader, device, batch_size):
|
||||
model.eval()
|
||||
running_loss_sbp = 0.0
|
||||
running_loss_dbp = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, sbp_labels, dbp_labels in dataloader:
|
||||
inputs = inputs.to(device)
|
||||
sbp_labels = sbp_labels.to(device)
|
||||
dbp_labels = dbp_labels.to(device)
|
||||
|
||||
sbp_outputs, dbp_outputs = model(inputs)
|
||||
|
||||
sbp_outputs = sbp_outputs.squeeze(1) # 将输出形状从(batch_size, 1)变为(batch_size,)
|
||||
dbp_outputs = dbp_outputs.squeeze(1)
|
||||
|
||||
loss_sbp = criterion(sbp_outputs, sbp_labels)
|
||||
loss_dbp = criterion(dbp_outputs, dbp_labels)
|
||||
|
||||
running_loss_sbp += loss_sbp.item()
|
||||
running_loss_dbp += loss_dbp.item()
|
||||
|
||||
eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
|
||||
eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size
|
||||
|
||||
return eval_loss_sbp, eval_loss_dbp
|
||||
|
||||
|
||||
# 训练循环
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
data_type = 'X'
|
||||
|
||||
# 判断权重保存目录是否存在,不存在则创建
|
||||
if not os.path.exists('weights'):
|
||||
os.makedirs('weights')
|
||||
# 在其中创建data_type同名子文件夹
|
||||
os.makedirs(os.path.join('weights', data_type))
|
||||
else:
|
||||
# 判断子文件夹是否存在
|
||||
if not os.path.exists(os.path.join('weights', data_type)):
|
||||
os.makedirs(os.path.join('weights', data_type))
|
||||
|
||||
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type='X')
|
||||
|
||||
train_dataloader, val_dataloader = data_loader.get_dataloaders()
|
||||
|
||||
best_val_loss_sbp = float('inf')
|
||||
best_val_loss_dbp = float('inf')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
if epoch < warmup_epochs:
|
||||
warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
|
||||
train_loss = train(model, train_dataloader, epoch, device, batch_size)
|
||||
val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)
|
||||
|
||||
writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
|
||||
writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
|
||||
|
||||
print(
|
||||
f"Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")
|
||||
|
||||
if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
|
||||
best_val_loss_sbp = val_loss_sbp
|
||||
best_val_loss_dbp = val_loss_dbp
|
||||
torch.save(model.state_dict(),
|
||||
f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')
|
||||
|
||||
torch.save(model.state_dict(),
|
||||
f'weights/{data_type}/last.pth')
|
||||
torch.cuda.is_available()
|
||||
writer.close()
|
|
@ -0,0 +1,68 @@
|
|||
import cv2
|
||||
|
||||
from BPApi import BPModel
|
||||
|
||||
|
||||
def main():
|
||||
cap = cv2.VideoCapture(0) # 使用摄像头
|
||||
|
||||
#设置视频宽高
|
||||
cap.set(3, 1920)
|
||||
cap.set(4, 1080)
|
||||
|
||||
video_fs = cap.get(5)
|
||||
# print(video_fs)
|
||||
|
||||
# 加载模型
|
||||
model = BPModel(model_path=r'final/best.pth', fps=video_fs)
|
||||
|
||||
frames = []
|
||||
|
||||
text = ["calculating..."]
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
|
||||
# 检测人脸
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
faces = face_cascade.detectMultiScale(gray, 1.3, 5)
|
||||
|
||||
if faces is not None and len(faces) > 0:
|
||||
# 将第一个人脸区域的图像截取
|
||||
x, y, w, h = faces[0]
|
||||
face = frame[y:y + h, x:x + w]
|
||||
|
||||
frames.append(face)
|
||||
|
||||
cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 255, 0), 2)
|
||||
print(len(frames))
|
||||
|
||||
if len(frames) == 250:
|
||||
|
||||
sbp_outputs, dbp_outputs = model.predict(frames)
|
||||
|
||||
print(sbp_outputs, dbp_outputs)
|
||||
|
||||
text.clear()
|
||||
text.append('SBP: {:.2f} mmHg'.format(sbp_outputs))
|
||||
text.append('DBP: {:.2f} mmHg'.format(dbp_outputs))
|
||||
|
||||
frames = []
|
||||
# 去除列表最前面的100个元素
|
||||
# frames=frames[50:]
|
||||
|
||||
for i, t in enumerate(text):
|
||||
cv2.putText(frame, t, (10, 60 + i * 20), font, 0.6, (0, 255, 0), 2)
|
||||
cv2.imshow('Blood Pressure Detection', frame)
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
if key == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,34 @@
|
|||
# 基于视觉的表情识别系统
|
||||
|
||||
该项目是一个基于图像的表情识别系统。它使用MobileViT在人脸表情数据集上进行训练,然后可以从摄像头输入的视频中检测人脸,并为每个检测到的人脸预测表情类型,共支持8类表情。
|
||||
|
||||
## 核心文件
|
||||
|
||||
- `class_indices.json`: 包含表情类型标签和对应数值编码的映射。
|
||||
- `predict_api.py`: 包含图像预测模型的加载、预处理和推理逻辑。
|
||||
- `video.py`: 视频处理和可视化的主要脚本。
|
||||
- `best.pth`: 训练的模型权重文件。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`torch`、`torchvision`、`Pillow`和`dlib`。
|
||||
2. 运行`video.py`脚本。
|
||||
3. 脚本将打开默认摄像头,开始人脸检测和表情预测。
|
||||
4. 检测到的人脸周围会用矩形框标注,并显示预测的表情类型和置信度分数。
|
||||
5. 按`q`键退出程序。
|
||||
|
||||
## 模型介绍
|
||||
|
||||
该项目使用MobileViT作为基础模型,对人脸表情图像数据集进行训练,以预测人脸图像的表情类型。模型输出包含8个值,分别对应各表情类型的概率。
|
||||
|
||||
### 数据集介绍
|
||||
|
||||
该项目使用的表情图像数据集来自网络开源数据,数据集包含35887张标注了皮肤病类型的人体皮肤图像。
|
||||
|
||||
## 算法流程
|
||||
|
||||
1. **人脸检测**: 使用Dlib库中的预训练人脸检测器在视频帧中检测人脸。
|
||||
2. **预处理**: 对检测到的人脸图像进行缩放、裁剪和标准化等预处理,以满足模型的输入要求。
|
||||
3. **推理**: 将预处理后的图像输入到预训练的Mobile-ViT模型中,获得不同表情类型的概率预测结果。
|
||||
4. **后处理**: 选取概率最高的类别作为最终预测结果。
|
||||
5. **可视化**: 在视频帧上绘制人脸矩形框,并显示预测的表情类型和置信度分数。
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"0": "ANGRY",
|
||||
"1": "CONFUSED",
|
||||
"2": "DISGUST",
|
||||
"3": "FEAR",
|
||||
"4": "HAPPY",
|
||||
"5": "NATURAL",
|
||||
"6": "SAD",
|
||||
"7": "SHY",
|
||||
"8": "SURPRISED"
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"0": "生气",
|
||||
"1": "困惑",
|
||||
"2": "厌恶",
|
||||
"3": "恐惧",
|
||||
"4": "快乐",
|
||||
"5": "平静",
|
||||
"6": "伤心",
|
||||
"7": "害羞",
|
||||
"8": "惊喜"
|
||||
}
|
|
@ -0,0 +1,562 @@
|
|||
"""
|
||||
original code from apple:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .transformer import TransformerEncoder
|
||||
from .model_config import get_config
|
||||
|
||||
|
||||
def make_divisible(
|
||||
v: Union[float, int],
|
||||
divisor: Optional[int] = 8,
|
||||
min_value: Optional[Union[float, int]] = None,
|
||||
) -> Union[float, int]:
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Applies a 2D convolution over an input
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
|
||||
stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
|
||||
groups (Optional[int]): Number of groups in convolution. Default: 1
|
||||
bias (Optional[bool]): Use bias. Default: ``False``
|
||||
use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
|
||||
use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
|
||||
Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Optional[Union[int, Tuple[int, int]]] = 1,
|
||||
groups: Optional[int] = 1,
|
||||
bias: Optional[bool] = False,
|
||||
use_norm: Optional[bool] = True,
|
||||
use_act: Optional[bool] = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
assert isinstance(kernel_size, Tuple)
|
||||
assert isinstance(stride, Tuple)
|
||||
|
||||
padding = (
|
||||
int((kernel_size[0] - 1) / 2),
|
||||
int((kernel_size[1] - 1) / 2),
|
||||
)
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=bias
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
if use_norm:
|
||||
norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
|
||||
if use_act:
|
||||
act_layer = nn.SiLU()
|
||||
block.add_module(name="act", module=act_layer)
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""
|
||||
This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
|
||||
stride (int): Use convolutions with a stride. Default: 1
|
||||
expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
|
||||
skip_connection (Optional[bool]): Use skip-connection. Default: True
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
expand_ratio: Union[int, float],
|
||||
skip_connection: Optional[bool] = True,
|
||||
) -> None:
|
||||
assert stride in [1, 2]
|
||||
hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
|
||||
|
||||
super().__init__()
|
||||
|
||||
block = nn.Sequential()
|
||||
if expand_ratio != 1:
|
||||
block.add_module(
|
||||
name="exp_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=hidden_dim,
|
||||
kernel_size=1
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="conv_3x3",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=hidden_dim,
|
||||
stride=stride,
|
||||
kernel_size=3,
|
||||
groups=hidden_dim
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="red_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
use_act=False,
|
||||
use_norm=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.exp = expand_ratio
|
||||
self.stride = stride
|
||||
self.use_res_connect = (
|
||||
self.stride == 1 and in_channels == out_channels and skip_connection
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
if self.use_res_connect:
|
||||
return x + self.block(x)
|
||||
else:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
"""
|
||||
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
|
||||
Args:
|
||||
opts: command line arguments
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
|
||||
transformer_dim (int): Input dimension to the transformer unit
|
||||
ffn_dim (int): Dimension of the FFN block
|
||||
n_transformer_blocks (int): Number of transformer blocks. Default: 2
|
||||
head_dim (int): Head dimension in the multi-head attention. Default: 32
|
||||
attn_dropout (float): Dropout in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
|
||||
patch_h (int): Patch height for unfolding operation. Default: 8
|
||||
patch_w (int): Patch width for unfolding operation. Default: 8
|
||||
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
|
||||
conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
|
||||
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
transformer_dim: int,
|
||||
ffn_dim: int,
|
||||
n_transformer_blocks: int = 2,
|
||||
head_dim: int = 32,
|
||||
attn_dropout: float = 0.0,
|
||||
dropout: float = 0.0,
|
||||
ffn_dropout: float = 0.0,
|
||||
patch_h: int = 8,
|
||||
patch_w: int = 8,
|
||||
conv_ksize: Optional[int] = 3,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
conv_3x3_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
conv_1x1_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=transformer_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
use_norm=False,
|
||||
use_act=False
|
||||
)
|
||||
|
||||
conv_1x1_out = ConvLayer(
|
||||
in_channels=transformer_dim,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1
|
||||
)
|
||||
conv_3x3_out = ConvLayer(
|
||||
in_channels=2 * in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
|
||||
self.local_rep = nn.Sequential()
|
||||
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
|
||||
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
|
||||
|
||||
assert transformer_dim % head_dim == 0
|
||||
num_heads = transformer_dim // head_dim
|
||||
|
||||
global_rep = [
|
||||
TransformerEncoder(
|
||||
embed_dim=transformer_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout
|
||||
)
|
||||
for _ in range(n_transformer_blocks)
|
||||
]
|
||||
global_rep.append(nn.LayerNorm(transformer_dim))
|
||||
self.global_rep = nn.Sequential(*global_rep)
|
||||
|
||||
self.conv_proj = conv_1x1_out
|
||||
self.fusion = conv_3x3_out
|
||||
|
||||
self.patch_h = patch_h
|
||||
self.patch_w = patch_w
|
||||
self.patch_area = self.patch_w * self.patch_h
|
||||
|
||||
self.cnn_in_dim = in_channels
|
||||
self.cnn_out_dim = transformer_dim
|
||||
self.n_heads = num_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.n_blocks = n_transformer_blocks
|
||||
self.conv_ksize = conv_ksize
|
||||
|
||||
def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
|
||||
patch_w, patch_h = self.patch_w, self.patch_h
|
||||
patch_area = patch_w * patch_h
|
||||
batch_size, in_channels, orig_h, orig_w = x.shape
|
||||
|
||||
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
|
||||
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
|
||||
|
||||
interpolate = False
|
||||
if new_w != orig_w or new_h != orig_h:
|
||||
# Note: Padding can be done, but then it needs to be handled in attention function.
|
||||
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
||||
interpolate = True
|
||||
|
||||
# number of patches along width and height
|
||||
num_patch_w = new_w // patch_w # n_w
|
||||
num_patch_h = new_h // patch_h # n_h
|
||||
num_patches = num_patch_h * num_patch_w # N
|
||||
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
info_dict = {
|
||||
"orig_size": (orig_h, orig_w),
|
||||
"batch_size": batch_size,
|
||||
"interpolate": interpolate,
|
||||
"total_patches": num_patches,
|
||||
"num_patches_w": num_patch_w,
|
||||
"num_patches_h": num_patch_h,
|
||||
}
|
||||
|
||||
return x, info_dict
|
||||
|
||||
def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
|
||||
n_dim = x.dim()
|
||||
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
|
||||
x.shape
|
||||
)
|
||||
# [BP, N, C] --> [B, P, N, C]
|
||||
x = x.contiguous().view(
|
||||
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
|
||||
)
|
||||
|
||||
batch_size, pixels, num_patches, channels = x.size()
|
||||
num_patch_h = info_dict["num_patches_h"]
|
||||
num_patch_w = info_dict["num_patches_w"]
|
||||
|
||||
# [B, P, N, C] -> [B, C, N, P]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
|
||||
x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
|
||||
# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
|
||||
x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
|
||||
if info_dict["interpolate"]:
|
||||
x = F.interpolate(
|
||||
x,
|
||||
size=info_dict["orig_size"],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
res = x
|
||||
|
||||
fm = self.local_rep(x)
|
||||
|
||||
# convert feature map to patches
|
||||
patches, info_dict = self.unfolding(fm)
|
||||
|
||||
# learn global representations
|
||||
for transformer_layer in self.global_rep:
|
||||
patches = transformer_layer(patches)
|
||||
|
||||
# [B x Patch x Patches x C] -> [B x C x Patches x Patch]
|
||||
fm = self.folding(x=patches, info_dict=info_dict)
|
||||
|
||||
fm = self.conv_proj(fm)
|
||||
|
||||
fm = self.fusion(torch.cat((res, fm), dim=1))
|
||||
return fm
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""
|
||||
This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
"""
|
||||
def __init__(self, model_cfg: Dict, num_classes: int = 1000):
|
||||
super().__init__()
|
||||
|
||||
image_channels = 3
|
||||
out_channels = 16
|
||||
|
||||
self.conv_1 = ConvLayer(
|
||||
in_channels=image_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2
|
||||
)
|
||||
|
||||
self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
|
||||
self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
|
||||
self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
|
||||
self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
|
||||
self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])
|
||||
|
||||
exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
|
||||
self.conv_1x1_exp = ConvLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=exp_channels,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
self.classifier = nn.Sequential()
|
||||
self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
|
||||
self.classifier.add_module(name="flatten", module=nn.Flatten())
|
||||
if 0.0 < model_cfg["cls_dropout"] < 1.0:
|
||||
self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
|
||||
self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))
|
||||
|
||||
# weight init
|
||||
self.apply(self.init_parameters)
|
||||
|
||||
def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
block_type = cfg.get("block_type", "mobilevit")
|
||||
if block_type.lower() == "mobilevit":
|
||||
return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
|
||||
else:
|
||||
return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)
|
||||
|
||||
@staticmethod
|
||||
def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
output_channels = cfg.get("out_channels")
|
||||
num_blocks = cfg.get("num_blocks", 2)
|
||||
expand_ratio = cfg.get("expand_ratio", 4)
|
||||
block = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
stride = cfg.get("stride", 1) if i == 0 else 1
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channels,
|
||||
stride=stride,
|
||||
expand_ratio=expand_ratio
|
||||
)
|
||||
block.append(layer)
|
||||
input_channel = output_channels
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
|
||||
stride = cfg.get("stride", 1)
|
||||
block = []
|
||||
|
||||
if stride == 2:
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=cfg.get("out_channels"),
|
||||
stride=stride,
|
||||
expand_ratio=cfg.get("mv_expand_ratio", 4)
|
||||
)
|
||||
|
||||
block.append(layer)
|
||||
input_channel = cfg.get("out_channels")
|
||||
|
||||
transformer_dim = cfg["transformer_channels"]
|
||||
ffn_dim = cfg.get("ffn_dim")
|
||||
num_heads = cfg.get("num_heads", 4)
|
||||
head_dim = transformer_dim // num_heads
|
||||
|
||||
if transformer_dim % head_dim != 0:
|
||||
raise ValueError("Transformer input dimension should be divisible by head dimension. "
|
||||
"Got {} and {}.".format(transformer_dim, head_dim))
|
||||
|
||||
block.append(MobileViTBlock(
|
||||
in_channels=input_channel,
|
||||
transformer_dim=transformer_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
n_transformer_blocks=cfg.get("transformer_blocks", 1),
|
||||
patch_h=cfg.get("patch_h", 2),
|
||||
patch_w=cfg.get("patch_w", 2),
|
||||
dropout=cfg.get("dropout", 0.1),
|
||||
ffn_dropout=cfg.get("ffn_dropout", 0.0),
|
||||
attn_dropout=cfg.get("attn_dropout", 0.1),
|
||||
head_dim=head_dim,
|
||||
conv_ksize=3
|
||||
))
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def init_parameters(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if m.weight is not None:
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.ones_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.Linear,)):
|
||||
if m.weight is not None:
|
||||
nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.conv_1(x)
|
||||
x = self.layer_1(x)
|
||||
x = self.layer_2(x)
|
||||
|
||||
x = self.layer_3(x)
|
||||
x = self.layer_4(x)
|
||||
x = self.layer_5(x)
|
||||
x = self.conv_1x1_exp(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobile_vit_xx_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
|
||||
config = get_config("xx_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_x_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
|
||||
config = get_config("x_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
|
||||
config = get_config("small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
|
@ -0,0 +1,176 @@
|
|||
def get_config(mode: str = "xxs") -> dict:
|
||||
if mode == "xx_small":
|
||||
mv2_exp_mult = 2
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 16,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 24,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 48,
|
||||
"transformer_channels": 64,
|
||||
"ffn_dim": 128,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2, # 8,
|
||||
"patch_w": 2, # 8,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 80,
|
||||
"ffn_dim": 160,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2, # 4,
|
||||
"patch_w": 2, # 4,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "x_small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 48,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 120,
|
||||
"ffn_dim": 240,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 64,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 128,
|
||||
"transformer_channels": 192,
|
||||
"ffn_dim": 384,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 160,
|
||||
"transformer_channels": 240,
|
||||
"ffn_dim": 480,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
|
||||
config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})
|
||||
|
||||
return config
|
|
@ -0,0 +1,38 @@
|
|||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MyDataSet(Dataset):
|
||||
"""自定义数据集"""
|
||||
|
||||
def __init__(self, images_path: list, images_class: list, transform=None):
|
||||
self.images_path = images_path
|
||||
self.images_class = images_class
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_path)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img = Image.open(self.images_path[item])
|
||||
# RGB为彩色图片,L为灰度图片
|
||||
if img.mode != 'RGB':
|
||||
# img = img.convert('RGB')
|
||||
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
|
||||
label = self.images_class[item]
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
return img, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
# 官方实现的default_collate可以参考
|
||||
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
|
||||
images, labels = tuple(zip(*batch))
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
labels = torch.as_tensor(labels)
|
||||
return images, labels
|
|
@ -0,0 +1,61 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import mobile_vit_small as create_model
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
img_size = 224
|
||||
data_transform = transforms.Compose(
|
||||
[transforms.Resize(int(img_size * 1.14)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
|
||||
# load image
|
||||
img_path = "../tulip.jpg"
|
||||
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
|
||||
img = Image.open(img_path)
|
||||
plt.imshow(img)
|
||||
# [N, C, H, W]
|
||||
img = data_transform(img)
|
||||
# expand batch dimension
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# read class_indict
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
class_indict = json.load(f)
|
||||
|
||||
# create model
|
||||
model = create_model(num_classes=5).to(device)
|
||||
# load model weights
|
||||
model_weight_path = "./weights/best_model.pth"
|
||||
model.load_state_dict(torch.load(model_weight_path, map_location=device))
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# predict class
|
||||
output = torch.squeeze(model(img.to(device))).cpu()
|
||||
predict = torch.softmax(output, dim=0)
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
|
||||
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
|
||||
predict[predict_cla].numpy())
|
||||
plt.title(print_res)
|
||||
for i in range(len(predict)):
|
||||
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
|
||||
predict[i].numpy()))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,56 @@
|
|||
import json
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from .model import mobile_vit_small as create_model
|
||||
|
||||
class EmotionPredictor:
|
||||
def __init__(self, model_path, class_indices_path, img_size=224):
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.img_size = img_size
|
||||
self.data_transform = transforms.Compose([
|
||||
transforms.Resize(int(self.img_size * 1.14)),
|
||||
transforms.CenterCrop(self.img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
# Load class indices
|
||||
with open(class_indices_path, "r",encoding="utf-8") as f:
|
||||
self.class_indict = json.load(f)
|
||||
# Load model
|
||||
self.model = self.load_model(model_path)
|
||||
|
||||
def load_model(self, model_path):
|
||||
|
||||
model = create_model(num_classes=9).to(self.device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def predict(self, np_image):
|
||||
# Convert numpy image to PIL image
|
||||
img = Image.fromarray(np_image).convert('RGB')
|
||||
|
||||
# Transform image
|
||||
img = self.data_transform(img)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# Predict class
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(self.model(img.to(self.device))).cpu()
|
||||
probabilities = torch.softmax(output, dim=0)
|
||||
top_prob, top_catid = torch.topk(probabilities, 1)
|
||||
|
||||
top1 = {
|
||||
"name": self.class_indict[str(top_catid[0].item())],
|
||||
"score": top_prob[0].item(),
|
||||
"label": top_catid[0].item()
|
||||
}
|
||||
|
||||
return top1["name"]
|
||||
|
||||
# Example usage:
|
||||
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
|
||||
# result = predictor.predict("../tulip.jpg")
|
||||
# print(result)
|
|
@ -0,0 +1,135 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
|
||||
from my_dataset import MyDataSet
|
||||
from model import mobile_vit_xx_small as create_model
|
||||
from utils import read_split_data, train_one_epoch, evaluate
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if os.path.exists("./weights") is False:
|
||||
os.makedirs("./weights")
|
||||
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
|
||||
|
||||
img_size = 224
|
||||
data_transform = {
|
||||
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
|
||||
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
|
||||
|
||||
# 实例化训练数据集
|
||||
train_dataset = MyDataSet(images_path=train_images_path,
|
||||
images_class=train_images_label,
|
||||
transform=data_transform["train"])
|
||||
|
||||
# 实例化验证数据集
|
||||
val_dataset = MyDataSet(images_path=val_images_path,
|
||||
images_class=val_images_label,
|
||||
transform=data_transform["val"])
|
||||
|
||||
batch_size = args.batch_size
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
print('Using {} dataloader workers every process'.format(nw))
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=train_dataset.collate_fn)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=val_dataset.collate_fn)
|
||||
|
||||
model = create_model(num_classes=args.num_classes).to(device)
|
||||
|
||||
if args.weights != "":
|
||||
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
|
||||
weights_dict = torch.load(args.weights, map_location=device)
|
||||
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
|
||||
# 删除有关分类类别的权重
|
||||
for k in list(weights_dict.keys()):
|
||||
if "classifier" in k:
|
||||
del weights_dict[k]
|
||||
print(model.load_state_dict(weights_dict, strict=False))
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, para in model.named_parameters():
|
||||
# 除head外,其他权重全部冻结
|
||||
if "classifier" not in name:
|
||||
para.requires_grad_(False)
|
||||
else:
|
||||
print("training {}".format(name))
|
||||
|
||||
pg = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)
|
||||
|
||||
best_acc = 0.
|
||||
for epoch in range(args.epochs):
|
||||
# train
|
||||
train_loss, train_acc = train_one_epoch(model=model,
|
||||
optimizer=optimizer,
|
||||
data_loader=train_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
# validate
|
||||
val_loss, val_acc = evaluate(model=model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
|
||||
tb_writer.add_scalar(tags[0], train_loss, epoch)
|
||||
tb_writer.add_scalar(tags[1], train_acc, epoch)
|
||||
tb_writer.add_scalar(tags[2], val_loss, epoch)
|
||||
tb_writer.add_scalar(tags[3], val_acc, epoch)
|
||||
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
torch.save(model.state_dict(), "./weights/best_model.pth")
|
||||
|
||||
torch.save(model.state_dict(), "./weights/latest_model.pth")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_classes', type=int, default=5)
|
||||
parser.add_argument('--epochs', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--lr', type=float, default=0.0002)
|
||||
|
||||
# 数据集所在根目录
|
||||
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
|
||||
parser.add_argument('--data-path', type=str,
|
||||
default="/data/flower_photos")
|
||||
|
||||
# 预训练权重路径,如果不想载入就设置为空字符
|
||||
parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
|
||||
help='initial weights path')
|
||||
# 是否冻结权重
|
||||
parser.add_argument('--freeze-layers', type=bool, default=False)
|
||||
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
main(opt)
|
|
@ -0,0 +1,155 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
This layer applies a multi-head self- or cross-attention as described in
|
||||
`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
|
||||
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
num_heads (int): Number of heads in multi-head attention
|
||||
attn_dropout (float): Attention dropout. Default: 0.0
|
||||
bias (bool): Use bias or not. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if embed_dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
"Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
|
||||
self.__class__.__name__, embed_dim, num_heads
|
||||
)
|
||||
)
|
||||
|
||||
self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=attn_dropout)
|
||||
self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, x_q: Tensor) -> Tensor:
|
||||
# [N, P, C]
|
||||
b_sz, n_patches, in_channels = x_q.shape
|
||||
|
||||
# self-attention
|
||||
# [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
|
||||
qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)
|
||||
|
||||
# [N, P, 3, h, c] -> [N, h, 3, P, C]
|
||||
qkv = qkv.transpose(1, 3).contiguous()
|
||||
|
||||
# [N, h, 3, P, C] -> [N, h, P, C] x 3
|
||||
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
|
||||
query = query * self.scaling
|
||||
|
||||
# [N h, P, c] -> [N, h, c, P]
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
# QK^T
|
||||
# [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
|
||||
attn = torch.matmul(query, key)
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# weighted sum
|
||||
# [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
|
||||
out = torch.matmul(attn, value)
|
||||
|
||||
# [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
|
||||
out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""
|
||||
This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
ffn_latent_dim (int): Inner dimension of the FFN
|
||||
num_heads (int) : Number of heads in multi-head attention. Default: 8
|
||||
attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers. Default: 0.0
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_latent_dim: int,
|
||||
num_heads: Optional[int] = 8,
|
||||
attn_dropout: Optional[float] = 0.0,
|
||||
dropout: Optional[float] = 0.0,
|
||||
ffn_dropout: Optional[float] = 0.0,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
attn_unit = MultiHeadAttention(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
bias=True
|
||||
)
|
||||
|
||||
self.pre_norm_mha = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
attn_unit,
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
|
||||
self.pre_norm_ffn = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=ffn_dropout),
|
||||
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_dim = ffn_latent_dim
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.std_dropout = dropout
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# multi-head attention
|
||||
res = x
|
||||
x = self.pre_norm_mha(x)
|
||||
x = x + res
|
||||
|
||||
# feed forward network
|
||||
x = x + self.pre_norm_ffn(x)
|
||||
return x
|
|
@ -0,0 +1,56 @@
|
|||
import time
|
||||
import torch
|
||||
|
||||
batch_size = 8
|
||||
in_channels = 32
|
||||
patch_h = 2
|
||||
patch_w = 2
|
||||
num_patch_h = 16
|
||||
num_patch_w = 16
|
||||
num_patches = num_patch_h * num_patch_w
|
||||
patch_area = patch_h * patch_w
|
||||
|
||||
|
||||
def official(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def my_self(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(3, 4)
|
||||
# [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
print(torch.equal(official(t), my_self(t)))
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
official(t)
|
||||
print(f"official time: {time.time() - t1}")
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
my_self(t)
|
||||
print(f"self time: {time.time() - t1}")
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def read_split_data(root: str, val_rate: float = 0.2):
|
||||
random.seed(0) # 保证随机结果可复现
|
||||
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
|
||||
|
||||
# 遍历文件夹,一个文件夹对应一个类别
|
||||
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
|
||||
# 排序,保证各平台顺序一致
|
||||
flower_class.sort()
|
||||
# 生成类别名称以及对应的数字索引
|
||||
class_indices = dict((k, v) for v, k in enumerate(flower_class))
|
||||
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
|
||||
with open('class_indices.json', 'w') as json_file:
|
||||
json_file.write(json_str)
|
||||
|
||||
train_images_path = [] # 存储训练集的所有图片路径
|
||||
train_images_label = [] # 存储训练集图片对应索引信息
|
||||
val_images_path = [] # 存储验证集的所有图片路径
|
||||
val_images_label = [] # 存储验证集图片对应索引信息
|
||||
every_class_num = [] # 存储每个类别的样本总数
|
||||
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
|
||||
# 遍历每个文件夹下的文件
|
||||
for cla in flower_class:
|
||||
cla_path = os.path.join(root, cla)
|
||||
# 遍历获取supported支持的所有文件路径
|
||||
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
|
||||
if os.path.splitext(i)[-1] in supported]
|
||||
# 排序,保证各平台顺序一致
|
||||
images.sort()
|
||||
# 获取该类别对应的索引
|
||||
image_class = class_indices[cla]
|
||||
# 记录该类别的样本数量
|
||||
every_class_num.append(len(images))
|
||||
# 按比例随机采样验证样本
|
||||
val_path = random.sample(images, k=int(len(images) * val_rate))
|
||||
|
||||
for img_path in images:
|
||||
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
|
||||
val_images_path.append(img_path)
|
||||
val_images_label.append(image_class)
|
||||
else: # 否则存入训练集
|
||||
train_images_path.append(img_path)
|
||||
train_images_label.append(image_class)
|
||||
|
||||
print("{} images were found in the dataset.".format(sum(every_class_num)))
|
||||
print("{} images for training.".format(len(train_images_path)))
|
||||
print("{} images for validation.".format(len(val_images_path)))
|
||||
assert len(train_images_path) > 0, "number of training images must greater than 0."
|
||||
assert len(val_images_path) > 0, "number of validation images must greater than 0."
|
||||
|
||||
plot_image = False
|
||||
if plot_image:
|
||||
# 绘制每种类别个数柱状图
|
||||
plt.bar(range(len(flower_class)), every_class_num, align='center')
|
||||
# 将横坐标0,1,2,3,4替换为相应的类别名称
|
||||
plt.xticks(range(len(flower_class)), flower_class)
|
||||
# 在柱状图上添加数值标签
|
||||
for i, v in enumerate(every_class_num):
|
||||
plt.text(x=i, y=v + 5, s=str(v), ha='center')
|
||||
# 设置x坐标
|
||||
plt.xlabel('image class')
|
||||
# 设置y坐标
|
||||
plt.ylabel('number of images')
|
||||
# 设置柱状图的标题
|
||||
plt.title('flower class distribution')
|
||||
plt.show()
|
||||
|
||||
return train_images_path, train_images_label, val_images_path, val_images_label
|
||||
|
||||
|
||||
def plot_data_loader_image(data_loader):
|
||||
batch_size = data_loader.batch_size
|
||||
plot_num = min(batch_size, 4)
|
||||
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), json_path + " does not exist."
|
||||
json_file = open(json_path, 'r')
|
||||
class_indices = json.load(json_file)
|
||||
|
||||
for data in data_loader:
|
||||
images, labels = data
|
||||
for i in range(plot_num):
|
||||
# [C, H, W] -> [H, W, C]
|
||||
img = images[i].numpy().transpose(1, 2, 0)
|
||||
# 反Normalize操作
|
||||
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
|
||||
label = labels[i].item()
|
||||
plt.subplot(1, plot_num, i+1)
|
||||
plt.xlabel(class_indices[str(label)])
|
||||
plt.xticks([]) # 去掉x轴的刻度
|
||||
plt.yticks([]) # 去掉y轴的刻度
|
||||
plt.imshow(img.astype('uint8'))
|
||||
plt.show()
|
||||
|
||||
|
||||
def write_pickle(list_info: list, file_name: str):
|
||||
with open(file_name, 'wb') as f:
|
||||
pickle.dump(list_info, f)
|
||||
|
||||
|
||||
def read_pickle(file_name: str) -> list:
|
||||
with open(file_name, 'rb') as f:
|
||||
info_list = pickle.load(f)
|
||||
return info_list
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
|
||||
model.train()
|
||||
loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
optimizer.zero_grad()
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
loss.backward()
|
||||
accu_loss += loss.detach()
|
||||
|
||||
data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
if not torch.isfinite(loss):
|
||||
print('WARNING: non-finite loss, ending training ', loss)
|
||||
sys.exit(1)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, data_loader, device, epoch):
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
model.eval()
|
||||
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
accu_loss += loss
|
||||
|
||||
data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
|
@ -0,0 +1,84 @@
|
|||
import cv2
|
||||
import dlib
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from predict_api import ImagePredictor
|
||||
|
||||
|
||||
def draw_chinese_text(image, text, position, color=(0, 255, 0)):
|
||||
# Convert cv2 image to PIL image
|
||||
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
# Create a blank image with alpha channel, same size as original image
|
||||
blank = Image.new('RGBA', image_pil.size, (0, 0, 0, 0))
|
||||
|
||||
# Create a draw object and draw text on the blank image
|
||||
draw = ImageDraw.Draw(blank)
|
||||
font = ImageFont.truetype("simhei.ttf", 20)
|
||||
draw.text(position, text, fill=color, font=font)
|
||||
|
||||
# Composite the original image with the blank image
|
||||
image_pil = Image.alpha_composite(image_pil.convert('RGBA'), blank)
|
||||
|
||||
# Convert PIL image back to cv2 image
|
||||
image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# Initialize face detector
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
|
||||
# Initialize ImagePredictor
|
||||
predictor = ImagePredictor(model_path="./best.pth", class_indices_path="./class_indices.json")
|
||||
|
||||
# Open the webcam
|
||||
cap = cv2.VideoCapture(0)
|
||||
|
||||
while True:
|
||||
# Read a frame from the webcam
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Convert the frame to grayscale
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Detect faces in the frame
|
||||
faces = detector(gray)
|
||||
|
||||
for rect in faces:
|
||||
# Get the coordinates of the face rectangle
|
||||
x = rect.left()
|
||||
y = rect.top()
|
||||
w = rect.width()
|
||||
h = rect.height()
|
||||
|
||||
# Crop the face from the frame
|
||||
face = frame[y:y+h, x:x+w]
|
||||
|
||||
# Predict the emotion of the face
|
||||
result = predictor.predict(face)
|
||||
|
||||
# Get the emotion with the highest score
|
||||
emotion = result["result"]["name"]
|
||||
|
||||
# Draw the rectangle around the face
|
||||
cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
|
||||
|
||||
# Put the emotion text above the rectangle cv2
|
||||
# cv2.putText(frame, emotion, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
# Put the emotion text above the rectangle PIL
|
||||
frame = draw_chinese_text(frame, emotion, (x, y))
|
||||
|
||||
# Display the frame
|
||||
cv2.imshow("Emotion Recognition", frame)
|
||||
|
||||
# Break the loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
# Release the webcam and destroy all windows
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
|
@ -0,0 +1,187 @@
|
|||
import dlib
|
||||
import numpy as np
|
||||
import scipy.fftpack as fftpack
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.decomposition import FastICA
|
||||
import cv2
|
||||
from scipy import signal
|
||||
|
||||
|
||||
class HeartRateMonitor:
|
||||
def __init__(self, fps, freqs_min, freqs_max):
|
||||
self.fps = fps
|
||||
self.freqs_min = freqs_min
|
||||
self.freqs_max = freqs_max
|
||||
self.all_hr_values = []
|
||||
self.bvp_signal = []
|
||||
|
||||
def get_channel_signal(self, ROI):
|
||||
blue = []
|
||||
green = []
|
||||
red = []
|
||||
for roi in ROI:
|
||||
b, g, r = cv2.split(roi)
|
||||
b = np.mean(np.sum(b)) / np.std(b)
|
||||
g = np.mean(np.sum(g)) / np.std(g)
|
||||
r = np.mean(np.sum(r)) / np.std(r)
|
||||
blue.append(b)
|
||||
green.append(g)
|
||||
red.append(r)
|
||||
return blue, green, red
|
||||
|
||||
def ICA(self, matrix, n_component, max_iter=200):
|
||||
matrix = matrix.T
|
||||
ica = FastICA(n_components=n_component, max_iter=max_iter)
|
||||
u = ica.fit_transform(matrix)
|
||||
return u.T
|
||||
|
||||
def fft_filter(self, signal):
|
||||
fft = fftpack.fft(signal, axis=0)
|
||||
frequencies = fftpack.fftfreq(signal.shape[0], d=1.0 / self.fps)
|
||||
bound_low = (np.abs(frequencies - self.freqs_min)).argmin()
|
||||
bound_high = (np.abs(frequencies - self.freqs_max)).argmin()
|
||||
fft[:bound_low] = 0
|
||||
fft[bound_high:-bound_high] = 0
|
||||
fft[-bound_low:] = 0
|
||||
return fft, frequencies
|
||||
|
||||
def find_heart_rate(self, fft, freqs):
|
||||
fft_maximums = []
|
||||
|
||||
for i in range(fft.shape[0]):
|
||||
if self.freqs_min <= freqs[i] <= self.freqs_max:
|
||||
fftMap = abs(fft[i])
|
||||
fft_maximums.append(fftMap.max())
|
||||
else:
|
||||
fft_maximums.append(0)
|
||||
|
||||
peaks, properties = signal.find_peaks(fft_maximums)
|
||||
max_peak = -1
|
||||
max_freq = 0
|
||||
|
||||
for peak in peaks:
|
||||
if fft_maximums[peak] > max_freq:
|
||||
max_freq = fft_maximums[peak]
|
||||
max_peak = peak
|
||||
|
||||
return freqs[max_peak] * 60
|
||||
|
||||
def fourier_transform(self, signal, N, fs):
|
||||
result = fftpack.fft(signal, N)
|
||||
result = np.abs(result)
|
||||
freqs = np.arange(N) / N
|
||||
freqs = freqs * fs
|
||||
return result[:N // 2], freqs[:N // 2]
|
||||
|
||||
def calculate_hrv(self, hr_values, window_size=5):
|
||||
num_values = int(window_size * self.fps)
|
||||
start_idx = max(0, len(hr_values) - num_values)
|
||||
recent_hr_values = hr_values[start_idx:]
|
||||
rr_intervals = np.array(recent_hr_values)
|
||||
|
||||
# 计算SDNN
|
||||
sdnn = np.std(rr_intervals)
|
||||
|
||||
# 计算RMSSD
|
||||
nn_diffs = np.diff(rr_intervals)
|
||||
rmssd = np.sqrt(np.mean(nn_diffs ** 2))
|
||||
|
||||
# 计算CV R-R
|
||||
mean_rr = np.mean(rr_intervals)
|
||||
cv_rr = sdnn / mean_rr if mean_rr != 0 else 0
|
||||
|
||||
return sdnn, rmssd, cv_rr
|
||||
|
||||
def process_roi(self, ROI):
|
||||
blue, green, red = self.get_channel_signal(ROI)
|
||||
matrix = np.array([blue, green, red])
|
||||
component = self.ICA(matrix, 3)
|
||||
|
||||
# 使用三个组件的平均值作为rPPG信号
|
||||
rppg_signal = np.mean(component, axis=0)
|
||||
# 将rPPG信号转换为BVP信号
|
||||
self.bvp_signal = self.rppg_to_bvp(rppg_signal)
|
||||
# self.plot_bvp_signal()
|
||||
hr_values = []
|
||||
for i in range(3):
|
||||
fft, freqs = self.fft_filter(component[i])
|
||||
heartrate = self.find_heart_rate(fft, freqs)
|
||||
hr_values.append(heartrate)
|
||||
avg_hr = sum(hr_values) / 3
|
||||
self.all_hr_values.append(avg_hr)
|
||||
sdnn, rmssd, cv_rr = self.calculate_hrv(self.all_hr_values, window_size=5)
|
||||
return avg_hr, sdnn, rmssd, cv_rr, self.bvp_signal
|
||||
|
||||
def rppg_to_bvp(self, rppg_signal):
|
||||
# 应用带通滤波器
|
||||
nyquist = 0.5 * self.fps
|
||||
low = self.freqs_min / nyquist
|
||||
high = self.freqs_max / nyquist
|
||||
b, a = signal.butter(3, [low, high], btype='band')
|
||||
bvp = signal.filtfilt(b, a, rppg_signal)
|
||||
|
||||
# 标准化
|
||||
bvp = (bvp - np.mean(bvp)) / np.std(bvp)
|
||||
|
||||
return bvp
|
||||
|
||||
def plot_bvp_signal(self):
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.plot(self.bvp_signal)
|
||||
plt.title('BVP Signal (Average of 3 ICA Components)')
|
||||
plt.xlabel('Samples')
|
||||
plt.ylabel('Amplitude')
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
ROI = []
|
||||
|
||||
freqs_min = 0.8
|
||||
freqs_max = 1.8
|
||||
heartrate = 0
|
||||
sdnn, rmssd, cv_rr = 0, 0, 0
|
||||
camera_code = 0
|
||||
capture = cv2.VideoCapture(camera_code)
|
||||
fps = capture.get(cv2.CAP_PROP_FPS)
|
||||
|
||||
hr_monitor = HeartRateMonitor(fps, freqs_min, freqs_max)
|
||||
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
while capture.isOpened():
|
||||
ret, frame = capture.read()
|
||||
if not ret:
|
||||
continue
|
||||
dects = detector(frame)
|
||||
for face in dects:
|
||||
left = face.left()
|
||||
right = face.right()
|
||||
top = face.top()
|
||||
bottom = face.bottom()
|
||||
|
||||
h = bottom - top
|
||||
w = right - left
|
||||
roi = frame[top + h // 10 * 2:top + h // 10 * 7, left + w // 9 * 2:left + w // 9 * 8]
|
||||
|
||||
cv2.rectangle(frame, (left + w // 9 * 2, top + h // 10 * 2), (left + w // 9 * 8, top + h // 10 * 7),
|
||||
color=(0, 0, 255))
|
||||
cv2.rectangle(frame, (left, top), (left + w, top + h), color=(0, 0, 255))
|
||||
ROI.append(roi)
|
||||
if len(ROI) == 300:
|
||||
heartrate, sdnn, rmssd, cv_rr = hr_monitor.process_roi(ROI)
|
||||
for i in range(30):
|
||||
ROI.pop(0)
|
||||
cv2.putText(frame, '{:.1f}bps, CV R-R: {:.2f}'.format(heartrate, cv_rr), (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.2,
|
||||
(255, 0, 255), 2)
|
||||
cv2.putText(frame, 'SDNN: {:.2f}, RMSSD: {:.2f}'.format(sdnn, rmssd), (50, 80),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1,
|
||||
(255, 0, 255), 2)
|
||||
cv2.imshow('frame', frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
capture.release()
|
|
@ -0,0 +1,58 @@
|
|||
# Video-based Heart Rate Monitoring
|
||||
|
||||
这个项目是一个基于视频的心率监测系统。它使用计算机视觉技术从人脸视频中提取心率信息。主要功能包括:
|
||||
|
||||
1. 检测人脸区域
|
||||
2. 从人脸区域提取RGB彩色通道信号
|
||||
3. 使用独立分量分析(ICA)从RGB信号中提取心率相关信号
|
||||
4. 使用FFT对信号进行频率分析,找出相应的心率值
|
||||
5. 计算心率变异性(HRV)指标,如SDNN、RMSSD和CV R-R
|
||||
|
||||
## 文件结构
|
||||
|
||||
- `HeartRateMonitor.py`: 实现心率监测算法的核心逻辑,以及算法演示程序。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`dlib`、`numpy`、`scipy`和`scikit-learn`
|
||||
2. 运行`HeartRateMonitor.py`脚本
|
||||
3. 脚本将打开默认摄像头,检测人脸区域
|
||||
4. 从人脸区域提取RGB彩色通道信号,使用ICA分离出心率信号
|
||||
5. 使用FFT分析心率信号,计算当前心率值
|
||||
6. 同时计算心率变异性指标SDNN、RMSSD和CV R-R
|
||||
7. 在视频画面上显示心率值和HRV指标
|
||||
|
||||
## 算法原理
|
||||
|
||||
### 心率信号提取
|
||||
|
||||
1. 从人脸ROI区域提取RGB三个通道的平均值和标准差
|
||||
2. 将RGB三个通道作为特征矩阵的三行输入ICA算法
|
||||
3. ICA算法将特征矩阵分解为3个独立分量
|
||||
4. 选择其中一个独立分量作为心率信号
|
||||
|
||||
### 心率计算
|
||||
|
||||
1. 对心率信号进行FFT变换得到频率域表示
|
||||
2. 根据设定的有效心率频率范围过滤FFT结果
|
||||
3. 在过滤后的FFT结果中找到最大值对应的频率,即为当前心率值(bpm)
|
||||
|
||||
### 心率变异性指标
|
||||
|
||||
1. 使用滑动窗口从最近的心率值序列中提取一段心率数据
|
||||
2. 计算该段数据的SDNN(标准差)、RMSSD(连续差分平方根值的均值)和CV R-R(R-R间期变异系数)
|
||||
3. 以上三个指标反映了心率的变异程度
|
||||
|
||||
## 参数说明
|
||||
|
||||
- `freqs_min`: 有效心率频率的下限(Hz)
|
||||
- `freqs_max`: 有效心率频率的上限(Hz)
|
||||
- `camera_code`: 使用的摄像头编号,0为默认摄像头
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 算法依赖人脸检测,如果人脸被遮挡或角度过大,将影响心率测量的准确性
|
||||
- 在光照条件较差的环境下,也可能影响测量精度
|
||||
- 目前只支持单个人脸的心率检测,多人情况下需要进一步改进
|
||||
- 算法的鲁棒性还有待提高,在特殊情况下可能会出现失效或测量偏差
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
# Video-based Respiration Rate Detection Algorithm
|
||||
|
||||
该项目是一个基于视频图像的呼吸频率检测算法的实现。它可以从视频中提取人体的呼吸曲线并计算呼吸频率。该算法使用了光流法、相关性引导的光流法、滤波、归一化等技术来提高检测精度。同时,它提供了多种呼吸频率计算方法供选择,包括FFT、Peak Counting、Crossing Point和Negative Feedback Crossover Point等。
|
||||
|
||||
## 文件结构
|
||||
|
||||
- `params.py`: 包含所有可配置的参数及其默认值。
|
||||
- `RespirationRateDetector.py`: 实现了呼吸频率检测算法的核心逻辑。
|
||||
- `demo.py`: 演示程序,从摄像头读取视频流并实时显示呼吸曲线和呼吸频率。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 克隆该项目到本地。
|
||||
2. 安装所需的Python依赖包,OpenCV、NumPy、SciPy、Matplotlib。
|
||||
3. 根据需要在`params.py`中调整参数设置。
|
||||
4. 运行`demo.py`启动演示程序。
|
||||
|
||||
程序将打开一个窗口显示从摄像头捕获的视频流,并在另一个窗口中绘制实时呼吸曲线。同时,它还会在视频窗口上显示使用不同方法计算得到的呼吸频率值。
|
||||
|
||||
## 核心算法
|
||||
|
||||
该算法的核心步骤包括:
|
||||
|
||||
1. **光流法**:使用光流法跟踪视频中的特征点,并计算这些特征点的运动幅度和。
|
||||
2. **相关性引导的光流法**:通过计算每个特征点与呼吸曲线的相关性,筛选出与呼吸相关的特征点,以提高检测精度。
|
||||
3. **滤波**:对原始呼吸曲线进行带通滤波,去除高频和低频噪声。
|
||||
4. **归一化**:将滤波后的呼吸曲线进行归一化处理。
|
||||
5. **呼吸频率计算**:使用FFT、Peak Counting、Crossing Point和Negative Feedback Crossover Point等多种方法计算呼吸频率。
|
||||
|
||||
## 参数说明
|
||||
|
||||
`params.py`中包含了该算法的所有可配置参数及其默认值。主要参数包括:
|
||||
|
||||
- `--video-path`: 输入视频文件的路径。默认值为'./1.mp4'。
|
||||
|
||||
- `--FSS`: 是否启用特征点选择策略(Feature Point Selection Strategy)。默认为True。
|
||||
- `--CGOF`: 是否启用相关性引导的光流法(Correlation-Guided Optical Flow Method)。默认为True。
|
||||
- `--filter`: 是否对呼吸曲线进行滤波。默认为True。
|
||||
- `--Normalization`: 是否对呼吸曲线进行归一化。默认为True。
|
||||
- `--RR_Evaluation`: 是否计算呼吸频率。默认为True。
|
||||
|
||||
其他参数控制光流法、特征点选择策略、滤波和呼吸频率计算的具体设置。
|
||||
|
||||
- `--OFP-maxCorners`: 光流法中检测特征点的最大数量。默认为100。
|
||||
- `--OFP-qualityLevel`: 光流法中特征点检测的质量等级。默认为0.1。
|
||||
- `--OFP-minDistance`: 光流法中特征点之间的最小距离。默认为7。
|
||||
- `--OFP-mask`: 光流法中使用的mask,用于指定感兴趣区域。默认为None。
|
||||
- `--OFP-QualityLevelRV`: 当无法检测到足够数量的特征点时,降低质量等级的步长值。默认为0.05。
|
||||
- `--OFP-winSize`: 光流法中金字塔Lucas-Kanade光流估计器的窗口大小。默认为(15,15)。
|
||||
- `--OFP-maxLevel`: 光流法中的金字塔层数。默认为2。
|
||||
|
||||
- `--FSS-switch`: 是否启用特征点选择策略。
|
||||
- `--FSS-maxCorners`: 特征点选择策略中检测特征点的最大数量。默认为100。
|
||||
- `--FSS-qualityLevel`: 特征点选择策略中特征点检测的质量等级。默认为0.1。
|
||||
- `--FSS-minDistance`: 特征点选择策略中特征点之间的最小距离。默认为7。
|
||||
- `--FSS-mask`: 特征点选择策略中使用的mask。默认为None。
|
||||
- `--FSS-QualityLevelRV`: 当无法检测到足够数量的特征点时,降低质量等级的步长值。默认为0.05。
|
||||
- `--FSS-FPN`: 特征点选择策略中要选择的特征点数量。默认为5。
|
||||
|
||||
- `--CGOF-switch`: 是否启用相关性引导的光流法。
|
||||
|
||||
- `--Filter-switch`: 是否对呼吸曲线进行滤波。
|
||||
- `--Filter-type`: 滤波器的类型,可选'lowpass'、'highpass'、'bandpass'和'bandstop'。默认为'bandpass'。
|
||||
- `--Filter-order`: 滤波器的阶数。默认为3。
|
||||
- `--Filter-LowPass`: 带通滤波器的低通频率(次/分钟)。默认为2。
|
||||
- `--Filter-HighPass`: 带通滤波器的高通频率(次/分钟)。默认为40。
|
||||
|
||||
- `--Normalization-switch`: 是否对呼吸曲线进行归一化。
|
||||
|
||||
- `--RR-switch`: 是否计算呼吸频率。
|
||||
|
||||
- `--RR-Algorithm-PC-Height`: Peak Counting算法中使用的峰值高度阈值。默认为None。
|
||||
- `--RR-Algorithm-PC-Threshold`: Peak Counting算法中使用的峰值门限。默认为None。
|
||||
- `--RR-Algorithm-PC-MaxRR`: Peak Counting算法中呼吸频率的最大值(次/分钟)。默认为45。
|
||||
- `--RR-Algorithm-CP-shfit_distance`: Crossing Point算法中使用的移位距离。默认为15。
|
||||
- `--RR-Algorithm-NFCP-shfit_distance`: Negative Feedback Crossover Point算法中使用的移位距离。默认为15。
|
||||
- `--RR-Algorithm-NFCP-qualityLevel`: Negative Feedback Crossover Point算法中使用的质量等级。默认为0.6。
|
|
@ -0,0 +1,233 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.fftpack import fft
|
||||
from scipy.signal import find_peaks
|
||||
|
||||
|
||||
class RespirationRateDetector:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def FeaturePointSelectionStrategy(self, Image, FPN=5, QualityLevel=0.3):
|
||||
Image_gray = Image
|
||||
feature_params = dict(maxCorners=self.args.FSS_maxCorners,
|
||||
qualityLevel=QualityLevel,
|
||||
minDistance=self.args.FSS_minDistance)
|
||||
|
||||
p0 = cv2.goodFeaturesToTrack(Image_gray, mask=self.args.FSS_mask, **feature_params)
|
||||
|
||||
""" Robust checking """
|
||||
while (p0 is None):
|
||||
QualityLevel = QualityLevel - self.args.FSS_QualityLevelRV
|
||||
feature_params = dict(maxCorners=self.args.FSS_maxCorners,
|
||||
qualityLevel=QualityLevel,
|
||||
minDistance=self.args.FSS_minDistance)
|
||||
p0 = cv2.goodFeaturesToTrack(Image_gray, mask=None, **feature_params)
|
||||
|
||||
if len(p0) < FPN:
|
||||
FPN = len(p0)
|
||||
|
||||
h = Image_gray.shape[0] / 2
|
||||
w = Image_gray.shape[1] / 2
|
||||
|
||||
p1 = p0.copy()
|
||||
p1[:, :, 0] -= w
|
||||
p1[:, :, 1] -= h
|
||||
p1_1 = np.multiply(p1, p1)
|
||||
p1_2 = np.sum(p1_1, 2)
|
||||
p1_3 = np.sqrt(p1_2)
|
||||
p1_4 = p1_3[:, 0]
|
||||
p1_5 = np.argsort(p1_4)
|
||||
|
||||
FPMap = np.zeros((FPN, 1, 2), dtype=np.float32)
|
||||
for i in range(FPN):
|
||||
FPMap[i, :, :] = p0[p1_5[i], :, :]
|
||||
|
||||
return FPMap
|
||||
|
||||
def CorrelationGuidedOpticalFlowMethod(self, FeatureMtx_Amp, RespCurve):
|
||||
CGAmp_Mtx = FeatureMtx_Amp.T
|
||||
CGAmpAugmented_Mtx = np.zeros((CGAmp_Mtx.shape[0] + 1, CGAmp_Mtx.shape[1]))
|
||||
CGAmpAugmented_Mtx[0, :] = RespCurve
|
||||
CGAmpAugmented_Mtx[1:, :] = CGAmp_Mtx
|
||||
|
||||
Correlation_Mtx = np.corrcoef(CGAmpAugmented_Mtx)
|
||||
CM_mean = np.mean(abs(Correlation_Mtx[0, 1:]))
|
||||
Quality_num = (abs(Correlation_Mtx[0, 1:]) >= CM_mean).sum()
|
||||
QualityFeaturePoint_arg = (abs(Correlation_Mtx[0, 1:]) >= CM_mean).argsort()[0 - Quality_num:]
|
||||
|
||||
CGOF_Mtx = np.zeros((FeatureMtx_Amp.shape[0], Quality_num))
|
||||
|
||||
for i in range(Quality_num):
|
||||
CGOF_Mtx[:, i] = FeatureMtx_Amp[:, QualityFeaturePoint_arg[i]]
|
||||
|
||||
CGOF_Mtx_RespCurve = np.sum(CGOF_Mtx, 1) / Quality_num
|
||||
|
||||
return CGOF_Mtx_RespCurve
|
||||
|
||||
def ImproveOpticalFlow(self, frames, fs):
|
||||
feature_params = dict(maxCorners=self.args.OFP_maxCorners,
|
||||
qualityLevel=self.args.OFP_qualityLevel,
|
||||
minDistance=self.args.OFP_minDistance)
|
||||
|
||||
old_frame = frames[0]
|
||||
old_gray = cv2.cvtColor(old_frame, cv2.COLOR_BGR2GRAY)
|
||||
p0 = cv2.goodFeaturesToTrack(old_gray, mask=self.args.OFP_mask, **feature_params)
|
||||
|
||||
""" Robust Checking """
|
||||
while (p0 is None):
|
||||
self.args.OFP_qualityLevel = self.args.OFP_qualityLevel - self.args.OFP_QualityLevelRV
|
||||
feature_params = dict(maxCorners=self.args.OFP_maxCorners,
|
||||
qualityLevel=self.args.OFP_qualityLevel,
|
||||
minDistance=self.args.OFP_minDistance)
|
||||
p0 = cv2.goodFeaturesToTrack(old_gray, mask=None, **feature_params)
|
||||
|
||||
""" FeaturePoint Selection Strategy """
|
||||
if self.args.FSS:
|
||||
p0 = self.FeaturePointSelectionStrategy(Image=old_gray, FPN=self.args.FSS_FPN,
|
||||
QualityLevel=self.args.FSS_qualityLevel)
|
||||
else:
|
||||
p0 = cv2.goodFeaturesToTrack(old_gray, mask=None, **feature_params)
|
||||
|
||||
lk_params = dict(winSize=self.args.OFP_winSize, maxLevel=self.args.OFP_maxLevel)
|
||||
total_frame = len(frames)
|
||||
|
||||
FeatureMtx = np.zeros((total_frame, p0.shape[0], 2))
|
||||
FeatureMtx[0, :, 0] = p0[:, 0, 0].T
|
||||
FeatureMtx[0, :, 1] = p0[:, 0, 1].T
|
||||
frame_num = 1
|
||||
|
||||
while (frame_num < total_frame):
|
||||
frame_num += 1
|
||||
frame = frames[frame_num - 1]
|
||||
frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
pl, st, err = cv2.calcOpticalFlowPyrLK(old_gray, frame_gray, p0, None, **lk_params)
|
||||
|
||||
old_gray = frame_gray.copy()
|
||||
p0 = pl.reshape(-1, 1, 2)
|
||||
FeatureMtx[frame_num - 1, :, 0] = p0[:, 0, 0].T
|
||||
FeatureMtx[frame_num - 1, :, 1] = p0[:, 0, 1].T
|
||||
|
||||
FeatureMtx_Amp = np.sqrt(FeatureMtx[:, :, 0] ** 2 + FeatureMtx[:, :, 1] ** 2)
|
||||
RespCurve = np.sum(FeatureMtx_Amp, 1) / p0.shape[0]
|
||||
|
||||
""" CCorrelation-Guided Optical Flow Method """
|
||||
if self.args.CGOF:
|
||||
RespCurve = self.CorrelationGuidedOpticalFlowMethod(FeatureMtx_Amp, RespCurve)
|
||||
|
||||
"""" Filter """
|
||||
if self.args.filter:
|
||||
original_signal = RespCurve
|
||||
#
|
||||
filter_order = self.args.Filter_order
|
||||
LowPass = self.args.Filter_LowPass / 60
|
||||
HighPass = self.args.Filter_HighPass / 60
|
||||
b, a = signal.butter(filter_order, [2 * LowPass / fs, 2 * HighPass / fs], self.args.Filter_type)
|
||||
filtedResp = signal.filtfilt(b, a, original_signal)
|
||||
else:
|
||||
filtedResp = RespCurve
|
||||
|
||||
""" Normalization """
|
||||
if self.args.Normalization:
|
||||
Resp_max = max(filtedResp)
|
||||
Resp_min = min(filtedResp)
|
||||
|
||||
Resp_norm = (filtedResp - Resp_min) / (Resp_max - Resp_min) - 0.5
|
||||
else:
|
||||
Resp_norm = filtedResp
|
||||
|
||||
return 1 - Resp_norm
|
||||
|
||||
def FFT(self, data, fs):
|
||||
fft_y = fft(data)
|
||||
maxFrequency = fs
|
||||
f = np.linspace(0, maxFrequency, len(data))
|
||||
abs_y = np.abs(fft_y)
|
||||
normalization_y = abs_y / len(data)
|
||||
normalization_half_y = normalization_y[range(int(len(data) / 2))]
|
||||
sorted_indices = np.argsort(normalization_half_y)
|
||||
RR = f[sorted_indices[-2]] * 60
|
||||
return RR
|
||||
|
||||
def PeakCounting(self, data, fs, Height=0.1, Threshold=0.2, MaxRR=30):
|
||||
Distance = 60 / MaxRR * fs
|
||||
peaks, _ = find_peaks(data, height=Height, threshold=Threshold, distance=Distance)
|
||||
RR = len(peaks) / (len(data) / fs) * 60
|
||||
return RR
|
||||
|
||||
def CrossingPoint(self, data, fs):
|
||||
shfit_distance = int(fs / 2)
|
||||
data_shift = np.zeros(data.shape) - 1
|
||||
data_shift[shfit_distance:] = data[:-shfit_distance]
|
||||
cross_curve = data - data_shift
|
||||
|
||||
zero_number = 0
|
||||
zero_index = []
|
||||
for i in range(len(cross_curve) - 1):
|
||||
if cross_curve[i] == 0:
|
||||
zero_number += 1
|
||||
zero_index.append(i)
|
||||
else:
|
||||
if cross_curve[i] * cross_curve[i + 1] < 0:
|
||||
zero_number += 1
|
||||
zero_index.append(i)
|
||||
|
||||
cw = zero_number
|
||||
N = len(data)
|
||||
RR1 = ((cw / 2) / (N / fs)) * 60
|
||||
|
||||
return RR1
|
||||
|
||||
def NegativeFeedbackCrossoverPointMethod(self, data, fs, QualityLevel=0.2):
|
||||
shfit_distance = int(fs / 2)
|
||||
data_shift = np.zeros(data.shape) - 1
|
||||
data_shift[shfit_distance:] = data[:-shfit_distance]
|
||||
cross_curve = data - data_shift
|
||||
|
||||
zero_number = 0
|
||||
zero_index = []
|
||||
for i in range(len(cross_curve) - 1):
|
||||
if cross_curve[i] == 0:
|
||||
zero_number += 1
|
||||
zero_index.append(i)
|
||||
else:
|
||||
if cross_curve[i] * cross_curve[i + 1] < 0:
|
||||
zero_number += 1
|
||||
zero_index.append(i)
|
||||
|
||||
cw = zero_number
|
||||
N = len(data)
|
||||
RR1 = ((cw / 2) / (N / fs)) * 60
|
||||
|
||||
if (len(zero_index) <= 1):
|
||||
RR2 = RR1
|
||||
else:
|
||||
time_span = 60 / RR1 / 2 * fs * QualityLevel
|
||||
zero_span = []
|
||||
for i in range(len(zero_index) - 1):
|
||||
zero_span.append(zero_index[i + 1] - zero_index[i])
|
||||
|
||||
while (min(zero_span) < time_span):
|
||||
doubt_point = np.argmin(zero_span)
|
||||
zero_index.pop(doubt_point)
|
||||
zero_index.pop(doubt_point)
|
||||
if len(zero_index) <= 1:
|
||||
break
|
||||
zero_span = []
|
||||
for i in range(len(zero_index) - 1):
|
||||
zero_span.append(zero_index[i + 1] - zero_index[i])
|
||||
|
||||
zero_number = len(zero_index)
|
||||
cw = zero_number
|
||||
RR2 = ((cw / 2) / (N / fs)) * 60
|
||||
|
||||
return RR2
|
||||
|
||||
def detect_respiration_rate(self, frames, fs):
|
||||
resp_curve = self.ImproveOpticalFlow(frames, fs)
|
||||
RR_FFT = self.FFT(resp_curve, fs)
|
||||
RR_PC = self.PeakCounting(resp_curve, fs)
|
||||
RR_CP = self.CrossingPoint(resp_curve, fs)
|
||||
RR_NFCP = self.NegativeFeedbackCrossoverPointMethod(resp_curve, fs)
|
||||
return resp_curve, RR_FFT, RR_PC, RR_CP, RR_NFCP
|
|
@ -0,0 +1,104 @@
|
|||
import queue
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from RespirationRateDetector import RespirationRateDetector
|
||||
from params import args
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def main():
|
||||
cap = cv2.VideoCapture(0) # 使用摄像头
|
||||
video_fs = cap.get(5)
|
||||
|
||||
detector = RespirationRateDetector(args)
|
||||
|
||||
frames = []
|
||||
|
||||
text = ["calculating..."]
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
# face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
|
||||
|
||||
resps = queue.Queue()
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
line, = ax.plot([], [])
|
||||
plt.ion()
|
||||
plt.show()
|
||||
|
||||
xdata = []
|
||||
ydata = []
|
||||
|
||||
last = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
|
||||
frames.append(frame)
|
||||
|
||||
if len(frames) == 300:
|
||||
|
||||
Resp, RR_FFT, RR_PC, RR_CP, RR_NFCP = detector.detect_respiration_rate(frames, video_fs)
|
||||
|
||||
Resp[0] = last
|
||||
|
||||
for res in Resp:
|
||||
resps.put(res)
|
||||
|
||||
last = Resp[-1]
|
||||
|
||||
text.clear()
|
||||
text.append('RR-FFT: {:.2f} bpm'.format(RR_FFT))
|
||||
text.append('RR-PC: {:.2f} bpm'.format(RR_PC))
|
||||
text.append('RR-CP: {:.2f} bpm'.format(RR_CP))
|
||||
text.append('RR-NFCP: {:.2f} bpm'.format(RR_NFCP))
|
||||
frames = []
|
||||
# 去除列表最前面的100个元素
|
||||
# frames=frames[50:]
|
||||
|
||||
if not resps.empty():
|
||||
|
||||
resp = resps.get()
|
||||
# 更新线的数据
|
||||
ydata.append(resp)
|
||||
|
||||
else:
|
||||
ydata.append(0)
|
||||
|
||||
if len(xdata) == 0:
|
||||
xdata.append(1)
|
||||
else:
|
||||
xdata.append(xdata[-1] + 1)
|
||||
|
||||
if len(xdata) > 600:
|
||||
xdata.pop(0)
|
||||
ydata.pop(0)
|
||||
|
||||
# 生成时间序列
|
||||
t = np.linspace(xdata[0] / video_fs, xdata[-1] / video_fs, len(ydata))
|
||||
|
||||
line.set_data(t, ydata) # 使用时间序列作为x轴
|
||||
|
||||
# 更新坐标轴的范围
|
||||
ax.set_xlim(t[0], t[-1])
|
||||
|
||||
ax.set_ylim(min(0, min(ydata)) - 0.5 * abs(min(ydata)), 1.5 * max(ydata))
|
||||
# 更新图表的显示
|
||||
plt.draw()
|
||||
plt.pause(0.01)
|
||||
|
||||
for i, t in enumerate(text):
|
||||
cv2.putText(frame, t, (10, 60 + i * 20), font, 0.6, (0, 255, 0), 2)
|
||||
cv2.imshow('Respiration Rate Detection', frame)
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
if key == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,55 @@
|
|||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser('Lightweight Video-based Respiration Rate Detection Algorithm script', add_help=False)
|
||||
parser.add_argument('--video-path', default='./1.mp4', help='Video input path')
|
||||
|
||||
parser.add_argument('--FSS', default=True, type=bool, help='')
|
||||
parser.add_argument('--CGOF', default=True, type=bool, help='')
|
||||
parser.add_argument('--filter', default=True, type=bool, help='')
|
||||
parser.add_argument('--Normalization', default=True, type=bool, help='')
|
||||
parser.add_argument('--RR_Evaluation', default=True, type=bool, help='')
|
||||
|
||||
# # Optical flow parameters
|
||||
parser.add_argument('--OFP-maxCorners', default=100, type=int, help='')
|
||||
parser.add_argument('--OFP-qualityLevel', default=0.1, type=float, help='')
|
||||
parser.add_argument('--OFP-minDistance', default=7, type=int, help='')
|
||||
parser.add_argument('--OFP-mask', default=None, help='')
|
||||
parser.add_argument('--OFP-QualityLevelRV', default=0.05, type=float, help='QualityLeve reduction value')
|
||||
parser.add_argument('--OFP-winSize', default=(15, 15), help='')
|
||||
parser.add_argument('--OFP-maxLevel', default=2, type=int, help='')
|
||||
|
||||
# # FeaturePoint Selection Strategy parameters
|
||||
parser.add_argument('--FSS-switch', action='store_true', dest='FSS_switch')
|
||||
parser.add_argument('--FSS-maxCorners', default=100, type=int, help='')
|
||||
parser.add_argument('--FSS-qualityLevel', default=0.1, type=float, help='')
|
||||
parser.add_argument('--FSS-minDistance', default=7, type=int, help='')
|
||||
parser.add_argument('--FSS-mask', default=None, help='')
|
||||
parser.add_argument('--FSS-QualityLevelRV', default=0.05, type=float, help='QualityLeve reduction value')
|
||||
parser.add_argument('--FSS-FPN', default=5, type=int,
|
||||
help='The number of feature points for the feature point selection strategy')
|
||||
|
||||
# # CCorrelation-Guided Optical Flow Method parameters
|
||||
parser.add_argument('--CGOF-switch', action='store_true', dest='CGOF_switch')
|
||||
|
||||
# # Filter parameters
|
||||
parser.add_argument('--Filter-switch', action='store_true', dest='Filter_switch')
|
||||
parser.add_argument('--Filter-type', default='bandpass', help='')
|
||||
parser.add_argument('--Filter-order', default=3, type=int, help='')
|
||||
parser.add_argument('--Filter-LowPass', default=2, type=int, help='')
|
||||
parser.add_argument('--Filter-HighPass', default=40, type=int, help='')
|
||||
|
||||
# # Normalization parameters
|
||||
parser.add_argument('--Normalization-switch', action='store_true', dest='Normalization_switch')
|
||||
|
||||
# # RR Evaluation parameters
|
||||
parser.add_argument('--RR-switch', action='store_true', dest='RR_switch')
|
||||
|
||||
# # RR Algorithm parameters
|
||||
parser.add_argument('--RR-Algorithm-PC-Height', default=None, help='')
|
||||
parser.add_argument('--RR-Algorithm-PC-Threshold', default=None, help='')
|
||||
parser.add_argument('--RR-Algorithm-PC-MaxRR', default=45, type=int, help='')
|
||||
parser.add_argument('--RR-Algorithm-CP-shfit_distance', default=15, type=int, help='')
|
||||
parser.add_argument('--RR-Algorithm-NFCP-shfit_distance', default=15, type=int, help='')
|
||||
parser.add_argument('--RR-Algorithm-NFCP-qualityLevel', default=0.6, type=float, help='')
|
||||
|
||||
args = parser.parse_args()
|
|
@ -0,0 +1,34 @@
|
|||
# 基于视觉的皮肤病检测系统
|
||||
|
||||
该项目是一个基于图像的皮肤病检测系统。它使用MobileViT在皮肤图像数据集上进行训练,然后可以从摄像头输入的视频中检测人脸,并为每个检测到的人脸预测皮肤病类型,共支持24类。
|
||||
|
||||
## 核心文件
|
||||
|
||||
- `class_indices.json`: 包含皮肤病类型标签和对应数值编码的映射。
|
||||
- `predict_api.py`: 包含图像预测模型的加载、预处理和推理逻辑。
|
||||
- `video.py`: 视频处理和可视化的主要脚本。
|
||||
- `best300_model_0.7302241690286009.pth`: 训练的模型权重文件。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`torch`、`torchvision`、`Pillow`和`dlib`。
|
||||
2. 运行`video.py`脚本。
|
||||
3. 脚本将打开默认摄像头,开始人脸检测和皮肤病预测。
|
||||
4. 检测到的人脸周围会用矩形框标注,并显示预测的皮肤病类型和置信度分数。
|
||||
5. 按`q`键退出程序。
|
||||
|
||||
## 模型介绍
|
||||
|
||||
该项目使用MobileViT作为基础模型,对皮肤病图像数据集进行训练,以预测人脸图像的皮肤类型。模型输出包含24个值,分别对应各皮肤病类型的概率。
|
||||
|
||||
### 数据集介绍
|
||||
|
||||
该项目使用的皮肤病图像数据集来自网络开源数据,数据集包含20000张标注了皮肤病类型的人体皮肤图像。
|
||||
|
||||
## 算法流程
|
||||
|
||||
1. **人脸检测**: 使用Dlib库中的预训练人脸检测器在视频帧中检测人脸。
|
||||
2. **预处理**: 对检测到的人脸图像进行缩放、裁剪和标准化等预处理,以满足模型的输入要求。
|
||||
3. **推理**: 将预处理后的图像输入到预训练的MobileViT模型中,获得不同皮肤病类型的概率预测结果。
|
||||
4. **后处理**: 选取概率最高的类别作为最终预测结果。
|
||||
5. **可视化**: 在视频帧上绘制人脸矩形框,并显示预测的皮肤病类型和置信度分数。
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"0": "痤疮或酒渣鼻",
|
||||
"1": "光化性角化病基底细胞癌或其他恶性病变",
|
||||
"2": "过敏性皮炎",
|
||||
"3": "大疱性疾病",
|
||||
"4": "蜂窝织炎、脓疱病或其他细菌感染",
|
||||
"5": "湿疹",
|
||||
"6": "皮疹或药疹",
|
||||
"7": "脱发或其他头发疾病",
|
||||
"8": "健康",
|
||||
"9": "疱疹、HPV或其他性病",
|
||||
"10": "轻度疾病和色素沉着障碍",
|
||||
"11": "狼疮或其他结缔组织疾病",
|
||||
"12": "黑色素瘤皮肤癌痣或痣",
|
||||
"13": "指甲真菌或其他指甲疾病",
|
||||
"14": "毒藤或其他接触性皮炎",
|
||||
"15": "牛皮癣、扁平苔藓或相关疾病",
|
||||
"16": "疥疮、莱姆病或其他感染和叮咬",
|
||||
"17": "脂溢性角化病或其他良性肿瘤",
|
||||
"18": "全身性疾病",
|
||||
"19": "癣念珠菌病或其他真菌感染",
|
||||
"20": "荨麻疹",
|
||||
"21": "血管肿瘤",
|
||||
"22": "血管炎",
|
||||
"23": "疣、软疣或其他病毒感染"
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,562 @@
|
|||
"""
|
||||
original code from apple:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .transformer import TransformerEncoder
|
||||
from .model_config import get_config
|
||||
|
||||
|
||||
def make_divisible(
|
||||
v: Union[float, int],
|
||||
divisor: Optional[int] = 8,
|
||||
min_value: Optional[Union[float, int]] = None,
|
||||
) -> Union[float, int]:
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Applies a 2D convolution over an input
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
|
||||
stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
|
||||
groups (Optional[int]): Number of groups in convolution. Default: 1
|
||||
bias (Optional[bool]): Use bias. Default: ``False``
|
||||
use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
|
||||
use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
|
||||
Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Optional[Union[int, Tuple[int, int]]] = 1,
|
||||
groups: Optional[int] = 1,
|
||||
bias: Optional[bool] = False,
|
||||
use_norm: Optional[bool] = True,
|
||||
use_act: Optional[bool] = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
assert isinstance(kernel_size, Tuple)
|
||||
assert isinstance(stride, Tuple)
|
||||
|
||||
padding = (
|
||||
int((kernel_size[0] - 1) / 2),
|
||||
int((kernel_size[1] - 1) / 2),
|
||||
)
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=bias
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
if use_norm:
|
||||
norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
|
||||
if use_act:
|
||||
act_layer = nn.SiLU()
|
||||
block.add_module(name="act", module=act_layer)
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""
|
||||
This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
|
||||
stride (int): Use convolutions with a stride. Default: 1
|
||||
expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
|
||||
skip_connection (Optional[bool]): Use skip-connection. Default: True
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
expand_ratio: Union[int, float],
|
||||
skip_connection: Optional[bool] = True,
|
||||
) -> None:
|
||||
assert stride in [1, 2]
|
||||
hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
|
||||
|
||||
super().__init__()
|
||||
|
||||
block = nn.Sequential()
|
||||
if expand_ratio != 1:
|
||||
block.add_module(
|
||||
name="exp_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=hidden_dim,
|
||||
kernel_size=1
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="conv_3x3",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=hidden_dim,
|
||||
stride=stride,
|
||||
kernel_size=3,
|
||||
groups=hidden_dim
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="red_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
use_act=False,
|
||||
use_norm=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.exp = expand_ratio
|
||||
self.stride = stride
|
||||
self.use_res_connect = (
|
||||
self.stride == 1 and in_channels == out_channels and skip_connection
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
if self.use_res_connect:
|
||||
return x + self.block(x)
|
||||
else:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
"""
|
||||
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
|
||||
Args:
|
||||
opts: command line arguments
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
|
||||
transformer_dim (int): Input dimension to the transformer unit
|
||||
ffn_dim (int): Dimension of the FFN block
|
||||
n_transformer_blocks (int): Number of transformer blocks. Default: 2
|
||||
head_dim (int): Head dimension in the multi-head attention. Default: 32
|
||||
attn_dropout (float): Dropout in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
|
||||
patch_h (int): Patch height for unfolding operation. Default: 8
|
||||
patch_w (int): Patch width for unfolding operation. Default: 8
|
||||
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
|
||||
conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
|
||||
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
transformer_dim: int,
|
||||
ffn_dim: int,
|
||||
n_transformer_blocks: int = 2,
|
||||
head_dim: int = 32,
|
||||
attn_dropout: float = 0.0,
|
||||
dropout: float = 0.0,
|
||||
ffn_dropout: float = 0.0,
|
||||
patch_h: int = 8,
|
||||
patch_w: int = 8,
|
||||
conv_ksize: Optional[int] = 3,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
conv_3x3_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
conv_1x1_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=transformer_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
use_norm=False,
|
||||
use_act=False
|
||||
)
|
||||
|
||||
conv_1x1_out = ConvLayer(
|
||||
in_channels=transformer_dim,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1
|
||||
)
|
||||
conv_3x3_out = ConvLayer(
|
||||
in_channels=2 * in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
|
||||
self.local_rep = nn.Sequential()
|
||||
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
|
||||
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
|
||||
|
||||
assert transformer_dim % head_dim == 0
|
||||
num_heads = transformer_dim // head_dim
|
||||
|
||||
global_rep = [
|
||||
TransformerEncoder(
|
||||
embed_dim=transformer_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout
|
||||
)
|
||||
for _ in range(n_transformer_blocks)
|
||||
]
|
||||
global_rep.append(nn.LayerNorm(transformer_dim))
|
||||
self.global_rep = nn.Sequential(*global_rep)
|
||||
|
||||
self.conv_proj = conv_1x1_out
|
||||
self.fusion = conv_3x3_out
|
||||
|
||||
self.patch_h = patch_h
|
||||
self.patch_w = patch_w
|
||||
self.patch_area = self.patch_w * self.patch_h
|
||||
|
||||
self.cnn_in_dim = in_channels
|
||||
self.cnn_out_dim = transformer_dim
|
||||
self.n_heads = num_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.n_blocks = n_transformer_blocks
|
||||
self.conv_ksize = conv_ksize
|
||||
|
||||
def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
|
||||
patch_w, patch_h = self.patch_w, self.patch_h
|
||||
patch_area = patch_w * patch_h
|
||||
batch_size, in_channels, orig_h, orig_w = x.shape
|
||||
|
||||
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
|
||||
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
|
||||
|
||||
interpolate = False
|
||||
if new_w != orig_w or new_h != orig_h:
|
||||
# Note: Padding can be done, but then it needs to be handled in attention function.
|
||||
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
||||
interpolate = True
|
||||
|
||||
# number of patches along width and height
|
||||
num_patch_w = new_w // patch_w # n_w
|
||||
num_patch_h = new_h // patch_h # n_h
|
||||
num_patches = num_patch_h * num_patch_w # N
|
||||
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
info_dict = {
|
||||
"orig_size": (orig_h, orig_w),
|
||||
"batch_size": batch_size,
|
||||
"interpolate": interpolate,
|
||||
"total_patches": num_patches,
|
||||
"num_patches_w": num_patch_w,
|
||||
"num_patches_h": num_patch_h,
|
||||
}
|
||||
|
||||
return x, info_dict
|
||||
|
||||
def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
|
||||
n_dim = x.dim()
|
||||
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
|
||||
x.shape
|
||||
)
|
||||
# [BP, N, C] --> [B, P, N, C]
|
||||
x = x.contiguous().view(
|
||||
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
|
||||
)
|
||||
|
||||
batch_size, pixels, num_patches, channels = x.size()
|
||||
num_patch_h = info_dict["num_patches_h"]
|
||||
num_patch_w = info_dict["num_patches_w"]
|
||||
|
||||
# [B, P, N, C] -> [B, C, N, P]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
|
||||
x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
|
||||
# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
|
||||
x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
|
||||
if info_dict["interpolate"]:
|
||||
x = F.interpolate(
|
||||
x,
|
||||
size=info_dict["orig_size"],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
res = x
|
||||
|
||||
fm = self.local_rep(x)
|
||||
|
||||
# convert feature map to patches
|
||||
patches, info_dict = self.unfolding(fm)
|
||||
|
||||
# learn global representations
|
||||
for transformer_layer in self.global_rep:
|
||||
patches = transformer_layer(patches)
|
||||
|
||||
# [B x Patch x Patches x C] -> [B x C x Patches x Patch]
|
||||
fm = self.folding(x=patches, info_dict=info_dict)
|
||||
|
||||
fm = self.conv_proj(fm)
|
||||
|
||||
fm = self.fusion(torch.cat((res, fm), dim=1))
|
||||
return fm
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""
|
||||
This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
"""
|
||||
def __init__(self, model_cfg: Dict, num_classes: int = 1000):
|
||||
super().__init__()
|
||||
|
||||
image_channels = 3
|
||||
out_channels = 16
|
||||
|
||||
self.conv_1 = ConvLayer(
|
||||
in_channels=image_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2
|
||||
)
|
||||
|
||||
self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
|
||||
self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
|
||||
self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
|
||||
self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
|
||||
self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])
|
||||
|
||||
exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
|
||||
self.conv_1x1_exp = ConvLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=exp_channels,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
self.classifier = nn.Sequential()
|
||||
self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
|
||||
self.classifier.add_module(name="flatten", module=nn.Flatten())
|
||||
if 0.0 < model_cfg["cls_dropout"] < 1.0:
|
||||
self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
|
||||
self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))
|
||||
|
||||
# weight init
|
||||
self.apply(self.init_parameters)
|
||||
|
||||
def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
block_type = cfg.get("block_type", "mobilevit")
|
||||
if block_type.lower() == "mobilevit":
|
||||
return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
|
||||
else:
|
||||
return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)
|
||||
|
||||
@staticmethod
|
||||
def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
output_channels = cfg.get("out_channels")
|
||||
num_blocks = cfg.get("num_blocks", 2)
|
||||
expand_ratio = cfg.get("expand_ratio", 4)
|
||||
block = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
stride = cfg.get("stride", 1) if i == 0 else 1
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channels,
|
||||
stride=stride,
|
||||
expand_ratio=expand_ratio
|
||||
)
|
||||
block.append(layer)
|
||||
input_channel = output_channels
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
|
||||
stride = cfg.get("stride", 1)
|
||||
block = []
|
||||
|
||||
if stride == 2:
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=cfg.get("out_channels"),
|
||||
stride=stride,
|
||||
expand_ratio=cfg.get("mv_expand_ratio", 4)
|
||||
)
|
||||
|
||||
block.append(layer)
|
||||
input_channel = cfg.get("out_channels")
|
||||
|
||||
transformer_dim = cfg["transformer_channels"]
|
||||
ffn_dim = cfg.get("ffn_dim")
|
||||
num_heads = cfg.get("num_heads", 4)
|
||||
head_dim = transformer_dim // num_heads
|
||||
|
||||
if transformer_dim % head_dim != 0:
|
||||
raise ValueError("Transformer input dimension should be divisible by head dimension. "
|
||||
"Got {} and {}.".format(transformer_dim, head_dim))
|
||||
|
||||
block.append(MobileViTBlock(
|
||||
in_channels=input_channel,
|
||||
transformer_dim=transformer_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
n_transformer_blocks=cfg.get("transformer_blocks", 1),
|
||||
patch_h=cfg.get("patch_h", 2),
|
||||
patch_w=cfg.get("patch_w", 2),
|
||||
dropout=cfg.get("dropout", 0.1),
|
||||
ffn_dropout=cfg.get("ffn_dropout", 0.0),
|
||||
attn_dropout=cfg.get("attn_dropout", 0.1),
|
||||
head_dim=head_dim,
|
||||
conv_ksize=3
|
||||
))
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def init_parameters(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if m.weight is not None:
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.ones_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.Linear,)):
|
||||
if m.weight is not None:
|
||||
nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.conv_1(x)
|
||||
x = self.layer_1(x)
|
||||
x = self.layer_2(x)
|
||||
|
||||
x = self.layer_3(x)
|
||||
x = self.layer_4(x)
|
||||
x = self.layer_5(x)
|
||||
x = self.conv_1x1_exp(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobile_vit_xx_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
|
||||
config = get_config("xx_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_x_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
|
||||
config = get_config("x_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
|
||||
config = get_config("small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
|
@ -0,0 +1,176 @@
|
|||
def get_config(mode: str = "xxs") -> dict:
|
||||
if mode == "xx_small":
|
||||
mv2_exp_mult = 2
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 16,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 24,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 48,
|
||||
"transformer_channels": 64,
|
||||
"ffn_dim": 128,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2, # 8,
|
||||
"patch_w": 2, # 8,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 80,
|
||||
"ffn_dim": 160,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2, # 4,
|
||||
"patch_w": 2, # 4,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "x_small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 48,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 120,
|
||||
"ffn_dim": 240,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 64,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 128,
|
||||
"transformer_channels": 192,
|
||||
"ffn_dim": 384,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 160,
|
||||
"transformer_channels": 240,
|
||||
"ffn_dim": 480,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
|
||||
config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})
|
||||
|
||||
return config
|
|
@ -0,0 +1,37 @@
|
|||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MyDataSet(Dataset):
|
||||
"""自定义数据集"""
|
||||
|
||||
def __init__(self, images_path: list, images_class: list, transform=None):
|
||||
self.images_path = images_path
|
||||
self.images_class = images_class
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_path)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img = Image.open(self.images_path[item])
|
||||
# RGB为彩色图片,L为灰度图片
|
||||
if img.mode != 'RGB':
|
||||
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
|
||||
label = self.images_class[item]
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
return img, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
# 官方实现的default_collate可以参考
|
||||
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
|
||||
images, labels = tuple(zip(*batch))
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
labels = torch.as_tensor(labels)
|
||||
return images, labels
|
|
@ -0,0 +1,64 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import mobile_vit_small as create_model
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
|
||||
#设置plt支持中文
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
img_size = 224
|
||||
data_transform = transforms.Compose(
|
||||
[transforms.Resize(int(img_size * 1.14)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
|
||||
# load image
|
||||
img_path = r"E:\Download\data\train\Acne and Rosacea Photos\acne-closed-comedo-8.jpg"
|
||||
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
|
||||
img = Image.open(img_path)
|
||||
plt.imshow(img)
|
||||
# [N, C, H, W]
|
||||
img = data_transform(img)
|
||||
# expand batch dimension
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# read class_indict
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
||||
|
||||
with open(json_path, "r",encoding="utf-8") as f:
|
||||
class_indict = json.load(f)
|
||||
|
||||
# create model
|
||||
model = create_model(num_classes=24).to(device)
|
||||
# load model weights
|
||||
model_weight_path = "./best300_model_0.7302241690286009.pth"
|
||||
model.load_state_dict(torch.load(model_weight_path, map_location=device))
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# predict class
|
||||
output = torch.squeeze(model(img.to(device))).cpu()
|
||||
predict = torch.softmax(output, dim=0)
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
|
||||
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
|
||||
predict[predict_cla].numpy())
|
||||
plt.title(print_res)
|
||||
for i in range(len(predict)):
|
||||
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
|
||||
predict[i].numpy()))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
import json
|
||||
import uuid
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from .model import mobile_vit_small as create_model
|
||||
|
||||
class SkinDiseasePredictor:
|
||||
def __init__(self, model_path, class_indices_path, img_size=224):
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.img_size = img_size
|
||||
self.data_transform = transforms.Compose([
|
||||
transforms.Resize(int(self.img_size * 1.14)),
|
||||
transforms.CenterCrop(self.img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
# Load class indices
|
||||
with open(class_indices_path, "r",encoding="utf-8") as f:
|
||||
self.class_indict = json.load(f)
|
||||
# Load model
|
||||
self.model = self.load_model(model_path)
|
||||
|
||||
def load_model(self, model_path):
|
||||
|
||||
model = create_model(num_classes=24).to(self.device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def predict_img(self, image_path):
|
||||
# Load and transform image
|
||||
assert os.path.exists(image_path), f"file: '{image_path}' does not exist."
|
||||
img = Image.open(image_path).convert('RGB')
|
||||
img = self.data_transform(img)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# Predict class
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(self.model(img.to(self.device))).cpu()
|
||||
probabilities = torch.softmax(output, dim=0)
|
||||
top_prob, top_catid = torch.topk(probabilities, 5)
|
||||
|
||||
# Top 5 results
|
||||
top5 = []
|
||||
for i in range(top_prob.size(0)):
|
||||
top5.append({
|
||||
"name": self.class_indict[str(top_catid[i].item())],
|
||||
"score": top_prob[i].item(),
|
||||
"label": top_catid[i].item()
|
||||
})
|
||||
|
||||
# Results dictionary
|
||||
|
||||
results = {"result": top5, "log_id": str(uuid.uuid1())}
|
||||
|
||||
return results
|
||||
def predict(self, np_image):
|
||||
# Convert numpy image to PIL image
|
||||
img = Image.fromarray(np_image).convert('RGB')
|
||||
|
||||
# Transform image
|
||||
img = self.data_transform(img)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# Predict class
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(self.model(img.to(self.device))).cpu()
|
||||
probabilities = torch.softmax(output, dim=0)
|
||||
top_prob, top_catid = torch.topk(probabilities, 1)
|
||||
|
||||
top1 = {
|
||||
"name": self.class_indict[str(top_catid[0].item())],
|
||||
"score": top_prob[0].item(),
|
||||
"label": top_catid[0].item()
|
||||
}
|
||||
|
||||
return top1["name"] if top1["score"] > 0.5 else "健康"
|
||||
|
||||
# Example usage:
|
||||
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
|
||||
# result = predictor.predict("../tulip.jpg")
|
||||
# print(result)
|
|
@ -0,0 +1,135 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
|
||||
from my_dataset import MyDataSet
|
||||
from model import mobile_vit_xx_small as create_model
|
||||
from utils import read_split_data, train_one_epoch, evaluate
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if os.path.exists("./weights") is False:
|
||||
os.makedirs("./weights")
|
||||
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
|
||||
|
||||
img_size = 224
|
||||
data_transform = {
|
||||
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
|
||||
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
|
||||
|
||||
# 实例化训练数据集
|
||||
train_dataset = MyDataSet(images_path=train_images_path,
|
||||
images_class=train_images_label,
|
||||
transform=data_transform["train"])
|
||||
|
||||
# 实例化验证数据集
|
||||
val_dataset = MyDataSet(images_path=val_images_path,
|
||||
images_class=val_images_label,
|
||||
transform=data_transform["val"])
|
||||
|
||||
batch_size = args.batch_size
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
print('Using {} dataloader workers every process'.format(nw))
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=train_dataset.collate_fn)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=val_dataset.collate_fn)
|
||||
|
||||
model = create_model(num_classes=args.num_classes).to(device)
|
||||
|
||||
if args.weights != "":
|
||||
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
|
||||
weights_dict = torch.load(args.weights, map_location=device)
|
||||
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
|
||||
# 删除有关分类类别的权重
|
||||
for k in list(weights_dict.keys()):
|
||||
if "classifier" in k:
|
||||
del weights_dict[k]
|
||||
print(model.load_state_dict(weights_dict, strict=False))
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, para in model.named_parameters():
|
||||
# 除head外,其他权重全部冻结
|
||||
if "classifier" not in name:
|
||||
para.requires_grad_(False)
|
||||
else:
|
||||
print("training {}".format(name))
|
||||
|
||||
pg = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)
|
||||
|
||||
best_acc = 0.
|
||||
for epoch in range(args.epochs):
|
||||
# train
|
||||
train_loss, train_acc = train_one_epoch(model=model,
|
||||
optimizer=optimizer,
|
||||
data_loader=train_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
# validate
|
||||
val_loss, val_acc = evaluate(model=model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
|
||||
tb_writer.add_scalar(tags[0], train_loss, epoch)
|
||||
tb_writer.add_scalar(tags[1], train_acc, epoch)
|
||||
tb_writer.add_scalar(tags[2], val_loss, epoch)
|
||||
tb_writer.add_scalar(tags[3], val_acc, epoch)
|
||||
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
torch.save(model.state_dict(), "./weights/best_model.pth")
|
||||
|
||||
torch.save(model.state_dict(), "./weights/latest_model.pth")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_classes', type=int, default=5)
|
||||
parser.add_argument('--epochs', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--lr', type=float, default=0.0002)
|
||||
|
||||
# 数据集所在根目录
|
||||
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
|
||||
parser.add_argument('--data-path', type=str,
|
||||
default="/data/flower_photos")
|
||||
|
||||
# 预训练权重路径,如果不想载入就设置为空字符
|
||||
parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
|
||||
help='initial weights path')
|
||||
# 是否冻结权重
|
||||
parser.add_argument('--freeze-layers', type=bool, default=False)
|
||||
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
main(opt)
|
|
@ -0,0 +1,155 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
This layer applies a multi-head self- or cross-attention as described in
|
||||
`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
|
||||
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
num_heads (int): Number of heads in multi-head attention
|
||||
attn_dropout (float): Attention dropout. Default: 0.0
|
||||
bias (bool): Use bias or not. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if embed_dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
"Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
|
||||
self.__class__.__name__, embed_dim, num_heads
|
||||
)
|
||||
)
|
||||
|
||||
self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=attn_dropout)
|
||||
self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, x_q: Tensor) -> Tensor:
|
||||
# [N, P, C]
|
||||
b_sz, n_patches, in_channels = x_q.shape
|
||||
|
||||
# self-attention
|
||||
# [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
|
||||
qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)
|
||||
|
||||
# [N, P, 3, h, c] -> [N, h, 3, P, C]
|
||||
qkv = qkv.transpose(1, 3).contiguous()
|
||||
|
||||
# [N, h, 3, P, C] -> [N, h, P, C] x 3
|
||||
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
|
||||
query = query * self.scaling
|
||||
|
||||
# [N h, P, c] -> [N, h, c, P]
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
# QK^T
|
||||
# [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
|
||||
attn = torch.matmul(query, key)
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# weighted sum
|
||||
# [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
|
||||
out = torch.matmul(attn, value)
|
||||
|
||||
# [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
|
||||
out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""
|
||||
This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
ffn_latent_dim (int): Inner dimension of the FFN
|
||||
num_heads (int) : Number of heads in multi-head attention. Default: 8
|
||||
attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers. Default: 0.0
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_latent_dim: int,
|
||||
num_heads: Optional[int] = 8,
|
||||
attn_dropout: Optional[float] = 0.0,
|
||||
dropout: Optional[float] = 0.0,
|
||||
ffn_dropout: Optional[float] = 0.0,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
attn_unit = MultiHeadAttention(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
bias=True
|
||||
)
|
||||
|
||||
self.pre_norm_mha = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
attn_unit,
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
|
||||
self.pre_norm_ffn = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=ffn_dropout),
|
||||
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_dim = ffn_latent_dim
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.std_dropout = dropout
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# multi-head attention
|
||||
res = x
|
||||
x = self.pre_norm_mha(x)
|
||||
x = x + res
|
||||
|
||||
# feed forward network
|
||||
x = x + self.pre_norm_ffn(x)
|
||||
return x
|
|
@ -0,0 +1,56 @@
|
|||
import time
|
||||
import torch
|
||||
|
||||
batch_size = 8
|
||||
in_channels = 32
|
||||
patch_h = 2
|
||||
patch_w = 2
|
||||
num_patch_h = 16
|
||||
num_patch_w = 16
|
||||
num_patches = num_patch_h * num_patch_w
|
||||
patch_area = patch_h * patch_w
|
||||
|
||||
|
||||
def official(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def my_self(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(3, 4)
|
||||
# [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
print(torch.equal(official(t), my_self(t)))
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
official(t)
|
||||
print(f"official time: {time.time() - t1}")
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
my_self(t)
|
||||
print(f"self time: {time.time() - t1}")
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def read_split_data(root: str, val_rate: float = 0.2):
|
||||
random.seed(0) # 保证随机结果可复现
|
||||
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
|
||||
|
||||
# 遍历文件夹,一个文件夹对应一个类别
|
||||
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
|
||||
# 排序,保证各平台顺序一致
|
||||
flower_class.sort()
|
||||
# 生成类别名称以及对应的数字索引
|
||||
class_indices = dict((k, v) for v, k in enumerate(flower_class))
|
||||
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
|
||||
with open('class_indices.json', 'w') as json_file:
|
||||
json_file.write(json_str)
|
||||
|
||||
train_images_path = [] # 存储训练集的所有图片路径
|
||||
train_images_label = [] # 存储训练集图片对应索引信息
|
||||
val_images_path = [] # 存储验证集的所有图片路径
|
||||
val_images_label = [] # 存储验证集图片对应索引信息
|
||||
every_class_num = [] # 存储每个类别的样本总数
|
||||
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
|
||||
# 遍历每个文件夹下的文件
|
||||
for cla in flower_class:
|
||||
cla_path = os.path.join(root, cla)
|
||||
# 遍历获取supported支持的所有文件路径
|
||||
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
|
||||
if os.path.splitext(i)[-1] in supported]
|
||||
# 排序,保证各平台顺序一致
|
||||
images.sort()
|
||||
# 获取该类别对应的索引
|
||||
image_class = class_indices[cla]
|
||||
# 记录该类别的样本数量
|
||||
every_class_num.append(len(images))
|
||||
# 按比例随机采样验证样本
|
||||
val_path = random.sample(images, k=int(len(images) * val_rate))
|
||||
|
||||
for img_path in images:
|
||||
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
|
||||
val_images_path.append(img_path)
|
||||
val_images_label.append(image_class)
|
||||
else: # 否则存入训练集
|
||||
train_images_path.append(img_path)
|
||||
train_images_label.append(image_class)
|
||||
|
||||
print("{} images were found in the dataset.".format(sum(every_class_num)))
|
||||
print("{} images for training.".format(len(train_images_path)))
|
||||
print("{} images for validation.".format(len(val_images_path)))
|
||||
assert len(train_images_path) > 0, "number of training images must greater than 0."
|
||||
assert len(val_images_path) > 0, "number of validation images must greater than 0."
|
||||
|
||||
plot_image = False
|
||||
if plot_image:
|
||||
# 绘制每种类别个数柱状图
|
||||
plt.bar(range(len(flower_class)), every_class_num, align='center')
|
||||
# 将横坐标0,1,2,3,4替换为相应的类别名称
|
||||
plt.xticks(range(len(flower_class)), flower_class)
|
||||
# 在柱状图上添加数值标签
|
||||
for i, v in enumerate(every_class_num):
|
||||
plt.text(x=i, y=v + 5, s=str(v), ha='center')
|
||||
# 设置x坐标
|
||||
plt.xlabel('image class')
|
||||
# 设置y坐标
|
||||
plt.ylabel('number of images')
|
||||
# 设置柱状图的标题
|
||||
plt.title('flower class distribution')
|
||||
plt.show()
|
||||
|
||||
return train_images_path, train_images_label, val_images_path, val_images_label
|
||||
|
||||
|
||||
def plot_data_loader_image(data_loader):
|
||||
batch_size = data_loader.batch_size
|
||||
plot_num = min(batch_size, 4)
|
||||
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), json_path + " does not exist."
|
||||
json_file = open(json_path, 'r')
|
||||
class_indices = json.load(json_file)
|
||||
|
||||
for data in data_loader:
|
||||
images, labels = data
|
||||
for i in range(plot_num):
|
||||
# [C, H, W] -> [H, W, C]
|
||||
img = images[i].numpy().transpose(1, 2, 0)
|
||||
# 反Normalize操作
|
||||
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
|
||||
label = labels[i].item()
|
||||
plt.subplot(1, plot_num, i+1)
|
||||
plt.xlabel(class_indices[str(label)])
|
||||
plt.xticks([]) # 去掉x轴的刻度
|
||||
plt.yticks([]) # 去掉y轴的刻度
|
||||
plt.imshow(img.astype('uint8'))
|
||||
plt.show()
|
||||
|
||||
|
||||
def write_pickle(list_info: list, file_name: str):
|
||||
with open(file_name, 'wb') as f:
|
||||
pickle.dump(list_info, f)
|
||||
|
||||
|
||||
def read_pickle(file_name: str) -> list:
|
||||
with open(file_name, 'rb') as f:
|
||||
info_list = pickle.load(f)
|
||||
return info_list
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
|
||||
model.train()
|
||||
loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
optimizer.zero_grad()
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
loss.backward()
|
||||
accu_loss += loss.detach()
|
||||
|
||||
data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
if not torch.isfinite(loss):
|
||||
print('WARNING: non-finite loss, ending training ', loss)
|
||||
sys.exit(1)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, data_loader, device, epoch):
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
model.eval()
|
||||
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
accu_loss += loss
|
||||
|
||||
data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
|
@ -0,0 +1,51 @@
|
|||
import cv2
|
||||
import dlib
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from predict_api import ImagePredictor
|
||||
|
||||
# Initialize camera and face detector
|
||||
cap = cv2.VideoCapture(0)
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
|
||||
# Initialize ImagePredictor
|
||||
predictor = ImagePredictor(model_path="best300_model_0.7302241690286009.pth", class_indices_path="./class_indices.json")
|
||||
|
||||
while True:
|
||||
# Capture frame-by-frame
|
||||
ret, frame = cap.read()
|
||||
|
||||
# Convert the image from BGR color (which OpenCV uses) to RGB color
|
||||
rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Perform face detection
|
||||
faces = detector(rgb_image)
|
||||
|
||||
# Loop through each face in this frame
|
||||
for rect in faces:
|
||||
# Get the bounding box coordinates
|
||||
x1, y1, x2, y2 = rect.left(), rect.top(), rect.right(), rect.bottom()
|
||||
|
||||
# Crop the face from the frame
|
||||
face_image = rgb_image[y1:y2, x1:x2]
|
||||
|
||||
# Use ImagePredictor to predict the class of this face
|
||||
result = predictor.predict(face_image)
|
||||
|
||||
# Draw a rectangle around the face
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Display the class name and score
|
||||
cv2.putText(frame, f"{result['result'][0]['name']}: {round(result['result'][0]['score'],4)}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
|
||||
|
||||
# Display the resulting frame
|
||||
cv2.imshow('Video', frame)
|
||||
|
||||
# Exit loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
# When everything is done, release the capture
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
|
@ -0,0 +1,34 @@
|
|||
# 基于视觉的皮肤类型检测系统
|
||||
|
||||
该项目是一个基于图像的皮肤类型检测系统。它使用MobileViT在皮肤图像数据集上进行训练,然后可以从摄像头输入的视频中检测人脸,并为每个检测到的人脸预测皮肤类型(干性、正常或油性)。
|
||||
|
||||
## 核心文件
|
||||
|
||||
- `class_indices.json`: 包含皮肤类型标签和对应数值编码的映射。
|
||||
- `predict_api.py`: 包含图像预测模型的加载、预处理和推理逻辑。
|
||||
- `video.py`: 视频处理和可视化的主要脚本。
|
||||
- `best_model_'0.8998410174880763'.pth`: 在皮肤图像数据集上训练的模型权重文件。
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 确保已安装所需的Python库,包括`opencv-python`、`torch`、`torchvision`、`Pillow`和`dlib`。
|
||||
2. 运行`video.py`脚本。
|
||||
3. 脚本将打开默认摄像头,开始人脸检测和皮肤类型预测。
|
||||
4. 检测到的人脸周围会用矩形框标注,并显示预测的皮肤类型和置信度分数。
|
||||
5. 按`q`键退出程序。
|
||||
|
||||
## 模型介绍
|
||||
|
||||
该项目使用MobileViT作为基础模型,对皮肤图像数据集进行训练,以预测人脸图像的皮肤类型。模型输出包含3个值,分别对应干性、正常和油性皮肤类型的概率。
|
||||
|
||||
### 数据集介绍
|
||||
|
||||
该项目使用的皮肤图像数据集来自Kaggle平台,数据集包含3152张标注了皮肤类型(干性、正常或油性)的人脸图像。
|
||||
|
||||
## 算法流程
|
||||
|
||||
1. **人脸检测**: 使用Dlib库中的预训练人脸检测器在视频帧中检测人脸。
|
||||
2. **预处理**: 对检测到的人脸图像进行缩放、裁剪和标准化等预处理,以满足模型的输入要求。
|
||||
3. **推理**: 将预处理后的图像输入到预训练的Mobile-ViT模型中,获得不同皮肤类型的概率预测结果。
|
||||
4. **后处理**: 选取概率最高的类别作为最终预测结果。
|
||||
5. **可视化**: 在视频帧上绘制人脸矩形框,并显示预测的皮肤类型和置信度分数。
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"0": "干性皮肤",
|
||||
"1": "正常皮肤",
|
||||
"2": "油性皮肤"
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,562 @@
|
|||
"""
|
||||
original code from apple:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .transformer import TransformerEncoder
|
||||
from .model_config import get_config
|
||||
|
||||
|
||||
def make_divisible(
|
||||
v: Union[float, int],
|
||||
divisor: Optional[int] = 8,
|
||||
min_value: Optional[Union[float, int]] = None,
|
||||
) -> Union[float, int]:
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Applies a 2D convolution over an input
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
|
||||
stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
|
||||
groups (Optional[int]): Number of groups in convolution. Default: 1
|
||||
bias (Optional[bool]): Use bias. Default: ``False``
|
||||
use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
|
||||
use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
|
||||
Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Optional[Union[int, Tuple[int, int]]] = 1,
|
||||
groups: Optional[int] = 1,
|
||||
bias: Optional[bool] = False,
|
||||
use_norm: Optional[bool] = True,
|
||||
use_act: Optional[bool] = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
assert isinstance(kernel_size, Tuple)
|
||||
assert isinstance(stride, Tuple)
|
||||
|
||||
padding = (
|
||||
int((kernel_size[0] - 1) / 2),
|
||||
int((kernel_size[1] - 1) / 2),
|
||||
)
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=bias
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
if use_norm:
|
||||
norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
|
||||
if use_act:
|
||||
act_layer = nn.SiLU()
|
||||
block.add_module(name="act", module=act_layer)
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""
|
||||
This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper
|
||||
|
||||
Args:
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
|
||||
stride (int): Use convolutions with a stride. Default: 1
|
||||
expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
|
||||
skip_connection (Optional[bool]): Use skip-connection. Default: True
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||
|
||||
.. note::
|
||||
If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
expand_ratio: Union[int, float],
|
||||
skip_connection: Optional[bool] = True,
|
||||
) -> None:
|
||||
assert stride in [1, 2]
|
||||
hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
|
||||
|
||||
super().__init__()
|
||||
|
||||
block = nn.Sequential()
|
||||
if expand_ratio != 1:
|
||||
block.add_module(
|
||||
name="exp_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=hidden_dim,
|
||||
kernel_size=1
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="conv_3x3",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=hidden_dim,
|
||||
stride=stride,
|
||||
kernel_size=3,
|
||||
groups=hidden_dim
|
||||
),
|
||||
)
|
||||
|
||||
block.add_module(
|
||||
name="red_1x1",
|
||||
module=ConvLayer(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
use_act=False,
|
||||
use_norm=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.exp = expand_ratio
|
||||
self.stride = stride
|
||||
self.use_res_connect = (
|
||||
self.stride == 1 and in_channels == out_channels and skip_connection
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
if self.use_res_connect:
|
||||
return x + self.block(x)
|
||||
else:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
"""
|
||||
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
|
||||
Args:
|
||||
opts: command line arguments
|
||||
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
|
||||
transformer_dim (int): Input dimension to the transformer unit
|
||||
ffn_dim (int): Dimension of the FFN block
|
||||
n_transformer_blocks (int): Number of transformer blocks. Default: 2
|
||||
head_dim (int): Head dimension in the multi-head attention. Default: 32
|
||||
attn_dropout (float): Dropout in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
|
||||
patch_h (int): Patch height for unfolding operation. Default: 8
|
||||
patch_w (int): Patch width for unfolding operation. Default: 8
|
||||
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
|
||||
conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
|
||||
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
transformer_dim: int,
|
||||
ffn_dim: int,
|
||||
n_transformer_blocks: int = 2,
|
||||
head_dim: int = 32,
|
||||
attn_dropout: float = 0.0,
|
||||
dropout: float = 0.0,
|
||||
ffn_dropout: float = 0.0,
|
||||
patch_h: int = 8,
|
||||
patch_w: int = 8,
|
||||
conv_ksize: Optional[int] = 3,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
conv_3x3_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
conv_1x1_in = ConvLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=transformer_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
use_norm=False,
|
||||
use_act=False
|
||||
)
|
||||
|
||||
conv_1x1_out = ConvLayer(
|
||||
in_channels=transformer_dim,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1
|
||||
)
|
||||
conv_3x3_out = ConvLayer(
|
||||
in_channels=2 * in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=conv_ksize,
|
||||
stride=1
|
||||
)
|
||||
|
||||
self.local_rep = nn.Sequential()
|
||||
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
|
||||
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
|
||||
|
||||
assert transformer_dim % head_dim == 0
|
||||
num_heads = transformer_dim // head_dim
|
||||
|
||||
global_rep = [
|
||||
TransformerEncoder(
|
||||
embed_dim=transformer_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout
|
||||
)
|
||||
for _ in range(n_transformer_blocks)
|
||||
]
|
||||
global_rep.append(nn.LayerNorm(transformer_dim))
|
||||
self.global_rep = nn.Sequential(*global_rep)
|
||||
|
||||
self.conv_proj = conv_1x1_out
|
||||
self.fusion = conv_3x3_out
|
||||
|
||||
self.patch_h = patch_h
|
||||
self.patch_w = patch_w
|
||||
self.patch_area = self.patch_w * self.patch_h
|
||||
|
||||
self.cnn_in_dim = in_channels
|
||||
self.cnn_out_dim = transformer_dim
|
||||
self.n_heads = num_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.n_blocks = n_transformer_blocks
|
||||
self.conv_ksize = conv_ksize
|
||||
|
||||
def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
|
||||
patch_w, patch_h = self.patch_w, self.patch_h
|
||||
patch_area = patch_w * patch_h
|
||||
batch_size, in_channels, orig_h, orig_w = x.shape
|
||||
|
||||
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
|
||||
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
|
||||
|
||||
interpolate = False
|
||||
if new_w != orig_w or new_h != orig_h:
|
||||
# Note: Padding can be done, but then it needs to be handled in attention function.
|
||||
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
||||
interpolate = True
|
||||
|
||||
# number of patches along width and height
|
||||
num_patch_w = new_w // patch_w # n_w
|
||||
num_patch_h = new_h // patch_h # n_h
|
||||
num_patches = num_patch_h * num_patch_w # N
|
||||
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
info_dict = {
|
||||
"orig_size": (orig_h, orig_w),
|
||||
"batch_size": batch_size,
|
||||
"interpolate": interpolate,
|
||||
"total_patches": num_patches,
|
||||
"num_patches_w": num_patch_w,
|
||||
"num_patches_h": num_patch_h,
|
||||
}
|
||||
|
||||
return x, info_dict
|
||||
|
||||
def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
|
||||
n_dim = x.dim()
|
||||
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
|
||||
x.shape
|
||||
)
|
||||
# [BP, N, C] --> [B, P, N, C]
|
||||
x = x.contiguous().view(
|
||||
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
|
||||
)
|
||||
|
||||
batch_size, pixels, num_patches, channels = x.size()
|
||||
num_patch_h = info_dict["num_patches_h"]
|
||||
num_patch_w = info_dict["num_patches_w"]
|
||||
|
||||
# [B, P, N, C] -> [B, C, N, P]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
|
||||
x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
|
||||
# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
|
||||
x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
|
||||
if info_dict["interpolate"]:
|
||||
x = F.interpolate(
|
||||
x,
|
||||
size=info_dict["orig_size"],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
res = x
|
||||
|
||||
fm = self.local_rep(x)
|
||||
|
||||
# convert feature map to patches
|
||||
patches, info_dict = self.unfolding(fm)
|
||||
|
||||
# learn global representations
|
||||
for transformer_layer in self.global_rep:
|
||||
patches = transformer_layer(patches)
|
||||
|
||||
# [B x Patch x Patches x C] -> [B x C x Patches x Patch]
|
||||
fm = self.folding(x=patches, info_dict=info_dict)
|
||||
|
||||
fm = self.conv_proj(fm)
|
||||
|
||||
fm = self.fusion(torch.cat((res, fm), dim=1))
|
||||
return fm
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""
|
||||
This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||
"""
|
||||
def __init__(self, model_cfg: Dict, num_classes: int = 1000):
|
||||
super().__init__()
|
||||
|
||||
image_channels = 3
|
||||
out_channels = 16
|
||||
|
||||
self.conv_1 = ConvLayer(
|
||||
in_channels=image_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2
|
||||
)
|
||||
|
||||
self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
|
||||
self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
|
||||
self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
|
||||
self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
|
||||
self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])
|
||||
|
||||
exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
|
||||
self.conv_1x1_exp = ConvLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=exp_channels,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
self.classifier = nn.Sequential()
|
||||
self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
|
||||
self.classifier.add_module(name="flatten", module=nn.Flatten())
|
||||
if 0.0 < model_cfg["cls_dropout"] < 1.0:
|
||||
self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
|
||||
self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))
|
||||
|
||||
# weight init
|
||||
self.apply(self.init_parameters)
|
||||
|
||||
def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
block_type = cfg.get("block_type", "mobilevit")
|
||||
if block_type.lower() == "mobilevit":
|
||||
return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
|
||||
else:
|
||||
return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)
|
||||
|
||||
@staticmethod
|
||||
def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
|
||||
output_channels = cfg.get("out_channels")
|
||||
num_blocks = cfg.get("num_blocks", 2)
|
||||
expand_ratio = cfg.get("expand_ratio", 4)
|
||||
block = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
stride = cfg.get("stride", 1) if i == 0 else 1
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channels,
|
||||
stride=stride,
|
||||
expand_ratio=expand_ratio
|
||||
)
|
||||
block.append(layer)
|
||||
input_channel = output_channels
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
|
||||
stride = cfg.get("stride", 1)
|
||||
block = []
|
||||
|
||||
if stride == 2:
|
||||
layer = InvertedResidual(
|
||||
in_channels=input_channel,
|
||||
out_channels=cfg.get("out_channels"),
|
||||
stride=stride,
|
||||
expand_ratio=cfg.get("mv_expand_ratio", 4)
|
||||
)
|
||||
|
||||
block.append(layer)
|
||||
input_channel = cfg.get("out_channels")
|
||||
|
||||
transformer_dim = cfg["transformer_channels"]
|
||||
ffn_dim = cfg.get("ffn_dim")
|
||||
num_heads = cfg.get("num_heads", 4)
|
||||
head_dim = transformer_dim // num_heads
|
||||
|
||||
if transformer_dim % head_dim != 0:
|
||||
raise ValueError("Transformer input dimension should be divisible by head dimension. "
|
||||
"Got {} and {}.".format(transformer_dim, head_dim))
|
||||
|
||||
block.append(MobileViTBlock(
|
||||
in_channels=input_channel,
|
||||
transformer_dim=transformer_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
n_transformer_blocks=cfg.get("transformer_blocks", 1),
|
||||
patch_h=cfg.get("patch_h", 2),
|
||||
patch_w=cfg.get("patch_w", 2),
|
||||
dropout=cfg.get("dropout", 0.1),
|
||||
ffn_dropout=cfg.get("ffn_dropout", 0.0),
|
||||
attn_dropout=cfg.get("attn_dropout", 0.1),
|
||||
head_dim=head_dim,
|
||||
conv_ksize=3
|
||||
))
|
||||
|
||||
return nn.Sequential(*block), input_channel
|
||||
|
||||
@staticmethod
|
||||
def init_parameters(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if m.weight is not None:
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.ones_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.Linear,)):
|
||||
if m.weight is not None:
|
||||
nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.conv_1(x)
|
||||
x = self.layer_1(x)
|
||||
x = self.layer_2(x)
|
||||
|
||||
x = self.layer_3(x)
|
||||
x = self.layer_4(x)
|
||||
x = self.layer_5(x)
|
||||
x = self.conv_1x1_exp(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobile_vit_xx_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
|
||||
config = get_config("xx_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_x_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
|
||||
config = get_config("x_small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
||||
|
||||
|
||||
def mobile_vit_small(num_classes: int = 1000):
|
||||
# pretrain weight link
|
||||
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
|
||||
config = get_config("small")
|
||||
m = MobileViT(config, num_classes=num_classes)
|
||||
return m
|
|
@ -0,0 +1,176 @@
|
|||
def get_config(mode: str = "xxs") -> dict:
|
||||
if mode == "xx_small":
|
||||
mv2_exp_mult = 2
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 16,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 24,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 48,
|
||||
"transformer_channels": 64,
|
||||
"ffn_dim": 128,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2, # 8,
|
||||
"patch_w": 2, # 8,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 80,
|
||||
"ffn_dim": 160,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2, # 4,
|
||||
"patch_w": 2, # 4,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "x_small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 48,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 64,
|
||||
"transformer_channels": 96,
|
||||
"ffn_dim": 192,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 80,
|
||||
"transformer_channels": 120,
|
||||
"ffn_dim": 240,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
elif mode == "small":
|
||||
mv2_exp_mult = 4
|
||||
config = {
|
||||
"layer1": {
|
||||
"out_channels": 32,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 1,
|
||||
"stride": 1,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer2": {
|
||||
"out_channels": 64,
|
||||
"expand_ratio": mv2_exp_mult,
|
||||
"num_blocks": 3,
|
||||
"stride": 2,
|
||||
"block_type": "mv2",
|
||||
},
|
||||
"layer3": { # 28x28
|
||||
"out_channels": 96,
|
||||
"transformer_channels": 144,
|
||||
"ffn_dim": 288,
|
||||
"transformer_blocks": 2,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer4": { # 14x14
|
||||
"out_channels": 128,
|
||||
"transformer_channels": 192,
|
||||
"ffn_dim": 384,
|
||||
"transformer_blocks": 4,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"layer5": { # 7x7
|
||||
"out_channels": 160,
|
||||
"transformer_channels": 240,
|
||||
"ffn_dim": 480,
|
||||
"transformer_blocks": 3,
|
||||
"patch_h": 2,
|
||||
"patch_w": 2,
|
||||
"stride": 2,
|
||||
"mv_expand_ratio": mv2_exp_mult,
|
||||
"num_heads": 4,
|
||||
"block_type": "mobilevit",
|
||||
},
|
||||
"last_layer_exp_factor": 4,
|
||||
"cls_dropout": 0.1
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
|
||||
config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})
|
||||
|
||||
return config
|
|
@ -0,0 +1,37 @@
|
|||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MyDataSet(Dataset):
|
||||
"""自定义数据集"""
|
||||
|
||||
def __init__(self, images_path: list, images_class: list, transform=None):
|
||||
self.images_path = images_path
|
||||
self.images_class = images_class
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_path)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img = Image.open(self.images_path[item])
|
||||
# RGB为彩色图片,L为灰度图片
|
||||
if img.mode != 'RGB':
|
||||
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
|
||||
label = self.images_class[item]
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
return img, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
# 官方实现的default_collate可以参考
|
||||
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
|
||||
images, labels = tuple(zip(*batch))
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
labels = torch.as_tensor(labels)
|
||||
return images, labels
|
|
@ -0,0 +1,64 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import mobile_vit_small as create_model
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
|
||||
#设置plt支持中文
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
img_size = 224
|
||||
data_transform = transforms.Compose(
|
||||
[transforms.Resize(int(img_size * 1.14)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
|
||||
# load image
|
||||
img_path = r"E:\Download\data\train\Acne and Rosacea Photos\acne-closed-comedo-8.jpg"
|
||||
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
|
||||
img = Image.open(img_path)
|
||||
plt.imshow(img)
|
||||
# [N, C, H, W]
|
||||
img = data_transform(img)
|
||||
# expand batch dimension
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# read class_indict
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
||||
|
||||
with open(json_path, "r",encoding="utf-8") as f:
|
||||
class_indict = json.load(f)
|
||||
|
||||
# create model
|
||||
model = create_model(num_classes=24).to(device)
|
||||
# load model weights
|
||||
model_weight_path = "./best300_model_0.7302241690286009.pth"
|
||||
model.load_state_dict(torch.load(model_weight_path, map_location=device))
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# predict class
|
||||
output = torch.squeeze(model(img.to(device))).cpu()
|
||||
predict = torch.softmax(output, dim=0)
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
|
||||
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
|
||||
predict[predict_cla].numpy())
|
||||
plt.title(print_res)
|
||||
for i in range(len(predict)):
|
||||
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
|
||||
predict[i].numpy()))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
import json
|
||||
import uuid
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from .model import mobile_vit_small as create_model
|
||||
|
||||
class SkinTypePredictor:
|
||||
def __init__(self, model_path, class_indices_path, img_size=224):
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.img_size = img_size
|
||||
self.data_transform = transforms.Compose([
|
||||
transforms.Resize(int(self.img_size * 1.14)),
|
||||
transforms.CenterCrop(self.img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
# Load class indices
|
||||
with open(class_indices_path, "r",encoding="utf-8") as f:
|
||||
self.class_indict = json.load(f)
|
||||
# Load model
|
||||
self.model = self.load_model(model_path)
|
||||
|
||||
def load_model(self, model_path):
|
||||
|
||||
model = create_model(num_classes=3).to(self.device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def predict_img(self, image_path):
|
||||
# Load and transform image
|
||||
assert os.path.exists(image_path), f"file: '{image_path}' does not exist."
|
||||
img = Image.open(image_path).convert('RGB')
|
||||
img = self.data_transform(img)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# Predict class
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(self.model(img.to(self.device))).cpu()
|
||||
probabilities = torch.softmax(output, dim=0)
|
||||
top_prob, top_catid = torch.topk(probabilities, 5)
|
||||
|
||||
# Top 5 results
|
||||
top5 = []
|
||||
for i in range(top_prob.size(0)):
|
||||
top5.append({
|
||||
"name": self.class_indict[str(top_catid[i].item())],
|
||||
"score": top_prob[i].item(),
|
||||
"label": top_catid[i].item()
|
||||
})
|
||||
|
||||
# Results dictionary
|
||||
|
||||
results = {"result": top5, "log_id": str(uuid.uuid1())}
|
||||
|
||||
return results
|
||||
def predict(self, np_image):
|
||||
# Convert numpy image to PIL image
|
||||
img = Image.fromarray(np_image).convert('RGB')
|
||||
|
||||
# Transform image
|
||||
img = self.data_transform(img)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# Predict class
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(self.model(img.to(self.device))).cpu()
|
||||
probabilities = torch.softmax(output, dim=0)
|
||||
top_prob, top_catid = torch.topk(probabilities, 1)
|
||||
|
||||
top1 = {
|
||||
"name": self.class_indict[str(top_catid[0].item())],
|
||||
"score": top_prob[0].item(),
|
||||
"label": top_catid[0].item()
|
||||
}
|
||||
|
||||
|
||||
return top1["name"]
|
||||
# Example usage:
|
||||
# predictor = SkinTypePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
|
||||
# result = predictor.predict("../tulip.jpg")
|
||||
# print(result)
|
|
@ -0,0 +1,135 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
|
||||
from my_dataset import MyDataSet
|
||||
from model import mobile_vit_xx_small as create_model
|
||||
from utils import read_split_data, train_one_epoch, evaluate
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if os.path.exists("./weights") is False:
|
||||
os.makedirs("./weights")
|
||||
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
|
||||
|
||||
img_size = 224
|
||||
data_transform = {
|
||||
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
|
||||
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
|
||||
|
||||
# 实例化训练数据集
|
||||
train_dataset = MyDataSet(images_path=train_images_path,
|
||||
images_class=train_images_label,
|
||||
transform=data_transform["train"])
|
||||
|
||||
# 实例化验证数据集
|
||||
val_dataset = MyDataSet(images_path=val_images_path,
|
||||
images_class=val_images_label,
|
||||
transform=data_transform["val"])
|
||||
|
||||
batch_size = args.batch_size
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
print('Using {} dataloader workers every process'.format(nw))
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=train_dataset.collate_fn)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=nw,
|
||||
collate_fn=val_dataset.collate_fn)
|
||||
|
||||
model = create_model(num_classes=args.num_classes).to(device)
|
||||
|
||||
if args.weights != "":
|
||||
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
|
||||
weights_dict = torch.load(args.weights, map_location=device)
|
||||
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
|
||||
# 删除有关分类类别的权重
|
||||
for k in list(weights_dict.keys()):
|
||||
if "classifier" in k:
|
||||
del weights_dict[k]
|
||||
print(model.load_state_dict(weights_dict, strict=False))
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, para in model.named_parameters():
|
||||
# 除head外,其他权重全部冻结
|
||||
if "classifier" not in name:
|
||||
para.requires_grad_(False)
|
||||
else:
|
||||
print("training {}".format(name))
|
||||
|
||||
pg = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)
|
||||
|
||||
best_acc = 0.
|
||||
for epoch in range(args.epochs):
|
||||
# train
|
||||
train_loss, train_acc = train_one_epoch(model=model,
|
||||
optimizer=optimizer,
|
||||
data_loader=train_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
# validate
|
||||
val_loss, val_acc = evaluate(model=model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
epoch=epoch)
|
||||
|
||||
tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
|
||||
tb_writer.add_scalar(tags[0], train_loss, epoch)
|
||||
tb_writer.add_scalar(tags[1], train_acc, epoch)
|
||||
tb_writer.add_scalar(tags[2], val_loss, epoch)
|
||||
tb_writer.add_scalar(tags[3], val_acc, epoch)
|
||||
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
torch.save(model.state_dict(), "./weights/best_model.pth")
|
||||
|
||||
torch.save(model.state_dict(), "./weights/latest_model.pth")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_classes', type=int, default=5)
|
||||
parser.add_argument('--epochs', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--lr', type=float, default=0.0002)
|
||||
|
||||
# 数据集所在根目录
|
||||
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
|
||||
parser.add_argument('--data-path', type=str,
|
||||
default="/data/flower_photos")
|
||||
|
||||
# 预训练权重路径,如果不想载入就设置为空字符
|
||||
parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
|
||||
help='initial weights path')
|
||||
# 是否冻结权重
|
||||
parser.add_argument('--freeze-layers', type=bool, default=False)
|
||||
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
main(opt)
|
|
@ -0,0 +1,155 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
This layer applies a multi-head self- or cross-attention as described in
|
||||
`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
|
||||
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
num_heads (int): Number of heads in multi-head attention
|
||||
attn_dropout (float): Attention dropout. Default: 0.0
|
||||
bias (bool): Use bias or not. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if embed_dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
"Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
|
||||
self.__class__.__name__, embed_dim, num_heads
|
||||
)
|
||||
)
|
||||
|
||||
self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=attn_dropout)
|
||||
self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, x_q: Tensor) -> Tensor:
|
||||
# [N, P, C]
|
||||
b_sz, n_patches, in_channels = x_q.shape
|
||||
|
||||
# self-attention
|
||||
# [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
|
||||
qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)
|
||||
|
||||
# [N, P, 3, h, c] -> [N, h, 3, P, C]
|
||||
qkv = qkv.transpose(1, 3).contiguous()
|
||||
|
||||
# [N, h, 3, P, C] -> [N, h, P, C] x 3
|
||||
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
|
||||
query = query * self.scaling
|
||||
|
||||
# [N h, P, c] -> [N, h, c, P]
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
# QK^T
|
||||
# [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
|
||||
attn = torch.matmul(query, key)
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# weighted sum
|
||||
# [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
|
||||
out = torch.matmul(attn, value)
|
||||
|
||||
# [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
|
||||
out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""
|
||||
This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
|
||||
ffn_latent_dim (int): Inner dimension of the FFN
|
||||
num_heads (int) : Number of heads in multi-head attention. Default: 8
|
||||
attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
|
||||
dropout (float): Dropout rate. Default: 0.0
|
||||
ffn_dropout (float): Dropout between FFN layers. Default: 0.0
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_latent_dim: int,
|
||||
num_heads: Optional[int] = 8,
|
||||
attn_dropout: Optional[float] = 0.0,
|
||||
dropout: Optional[float] = 0.0,
|
||||
ffn_dropout: Optional[float] = 0.0,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
attn_unit = MultiHeadAttention(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
bias=True
|
||||
)
|
||||
|
||||
self.pre_norm_mha = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
attn_unit,
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
|
||||
self.pre_norm_ffn = nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=ffn_dropout),
|
||||
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
|
||||
nn.Dropout(p=dropout)
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_dim = ffn_latent_dim
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.std_dropout = dropout
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# multi-head attention
|
||||
res = x
|
||||
x = self.pre_norm_mha(x)
|
||||
x = x + res
|
||||
|
||||
# feed forward network
|
||||
x = x + self.pre_norm_ffn(x)
|
||||
return x
|
|
@ -0,0 +1,56 @@
|
|||
import time
|
||||
import torch
|
||||
|
||||
batch_size = 8
|
||||
in_channels = 32
|
||||
patch_h = 2
|
||||
patch_w = 2
|
||||
num_patch_h = 16
|
||||
num_patch_w = 16
|
||||
num_patches = num_patch_h * num_patch_w
|
||||
patch_area = patch_h * patch_w
|
||||
|
||||
|
||||
def official(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(1, 2)
|
||||
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def my_self(x: torch.Tensor):
|
||||
# [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
|
||||
x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
|
||||
# [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
|
||||
x = x.transpose(3, 4)
|
||||
# [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||
# [B, C, N, P] -> [B, P, N, C]
|
||||
x = x.transpose(1, 3)
|
||||
# [B, P, N, C] -> [BP, N, C]
|
||||
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
print(torch.equal(official(t), my_self(t)))
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
official(t)
|
||||
print(f"official time: {time.time() - t1}")
|
||||
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
my_self(t)
|
||||
print(f"self time: {time.time() - t1}")
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def read_split_data(root: str, val_rate: float = 0.2):
|
||||
random.seed(0) # 保证随机结果可复现
|
||||
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
|
||||
|
||||
# 遍历文件夹,一个文件夹对应一个类别
|
||||
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
|
||||
# 排序,保证各平台顺序一致
|
||||
flower_class.sort()
|
||||
# 生成类别名称以及对应的数字索引
|
||||
class_indices = dict((k, v) for v, k in enumerate(flower_class))
|
||||
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
|
||||
with open('class_indices.json', 'w') as json_file:
|
||||
json_file.write(json_str)
|
||||
|
||||
train_images_path = [] # 存储训练集的所有图片路径
|
||||
train_images_label = [] # 存储训练集图片对应索引信息
|
||||
val_images_path = [] # 存储验证集的所有图片路径
|
||||
val_images_label = [] # 存储验证集图片对应索引信息
|
||||
every_class_num = [] # 存储每个类别的样本总数
|
||||
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
|
||||
# 遍历每个文件夹下的文件
|
||||
for cla in flower_class:
|
||||
cla_path = os.path.join(root, cla)
|
||||
# 遍历获取supported支持的所有文件路径
|
||||
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
|
||||
if os.path.splitext(i)[-1] in supported]
|
||||
# 排序,保证各平台顺序一致
|
||||
images.sort()
|
||||
# 获取该类别对应的索引
|
||||
image_class = class_indices[cla]
|
||||
# 记录该类别的样本数量
|
||||
every_class_num.append(len(images))
|
||||
# 按比例随机采样验证样本
|
||||
val_path = random.sample(images, k=int(len(images) * val_rate))
|
||||
|
||||
for img_path in images:
|
||||
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
|
||||
val_images_path.append(img_path)
|
||||
val_images_label.append(image_class)
|
||||
else: # 否则存入训练集
|
||||
train_images_path.append(img_path)
|
||||
train_images_label.append(image_class)
|
||||
|
||||
print("{} images were found in the dataset.".format(sum(every_class_num)))
|
||||
print("{} images for training.".format(len(train_images_path)))
|
||||
print("{} images for validation.".format(len(val_images_path)))
|
||||
assert len(train_images_path) > 0, "number of training images must greater than 0."
|
||||
assert len(val_images_path) > 0, "number of validation images must greater than 0."
|
||||
|
||||
plot_image = False
|
||||
if plot_image:
|
||||
# 绘制每种类别个数柱状图
|
||||
plt.bar(range(len(flower_class)), every_class_num, align='center')
|
||||
# 将横坐标0,1,2,3,4替换为相应的类别名称
|
||||
plt.xticks(range(len(flower_class)), flower_class)
|
||||
# 在柱状图上添加数值标签
|
||||
for i, v in enumerate(every_class_num):
|
||||
plt.text(x=i, y=v + 5, s=str(v), ha='center')
|
||||
# 设置x坐标
|
||||
plt.xlabel('image class')
|
||||
# 设置y坐标
|
||||
plt.ylabel('number of images')
|
||||
# 设置柱状图的标题
|
||||
plt.title('flower class distribution')
|
||||
plt.show()
|
||||
|
||||
return train_images_path, train_images_label, val_images_path, val_images_label
|
||||
|
||||
|
||||
def plot_data_loader_image(data_loader):
|
||||
batch_size = data_loader.batch_size
|
||||
plot_num = min(batch_size, 4)
|
||||
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), json_path + " does not exist."
|
||||
json_file = open(json_path, 'r')
|
||||
class_indices = json.load(json_file)
|
||||
|
||||
for data in data_loader:
|
||||
images, labels = data
|
||||
for i in range(plot_num):
|
||||
# [C, H, W] -> [H, W, C]
|
||||
img = images[i].numpy().transpose(1, 2, 0)
|
||||
# 反Normalize操作
|
||||
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
|
||||
label = labels[i].item()
|
||||
plt.subplot(1, plot_num, i+1)
|
||||
plt.xlabel(class_indices[str(label)])
|
||||
plt.xticks([]) # 去掉x轴的刻度
|
||||
plt.yticks([]) # 去掉y轴的刻度
|
||||
plt.imshow(img.astype('uint8'))
|
||||
plt.show()
|
||||
|
||||
|
||||
def write_pickle(list_info: list, file_name: str):
|
||||
with open(file_name, 'wb') as f:
|
||||
pickle.dump(list_info, f)
|
||||
|
||||
|
||||
def read_pickle(file_name: str) -> list:
|
||||
with open(file_name, 'rb') as f:
|
||||
info_list = pickle.load(f)
|
||||
return info_list
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
|
||||
model.train()
|
||||
loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
optimizer.zero_grad()
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
loss.backward()
|
||||
accu_loss += loss.detach()
|
||||
|
||||
data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
if not torch.isfinite(loss):
|
||||
print('WARNING: non-finite loss, ending training ', loss)
|
||||
sys.exit(1)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, data_loader, device, epoch):
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
model.eval()
|
||||
|
||||
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
|
||||
accu_loss = torch.zeros(1).to(device) # 累计损失
|
||||
|
||||
sample_num = 0
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
for step, data in enumerate(data_loader):
|
||||
images, labels = data
|
||||
sample_num += images.shape[0]
|
||||
|
||||
pred = model(images.to(device))
|
||||
pred_classes = torch.max(pred, dim=1)[1]
|
||||
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
|
||||
|
||||
loss = loss_function(pred, labels.to(device))
|
||||
accu_loss += loss
|
||||
|
||||
data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
|
||||
accu_loss.item() / (step + 1),
|
||||
accu_num.item() / sample_num)
|
||||
|
||||
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
|
@ -0,0 +1,51 @@
|
|||
import cv2
|
||||
import dlib
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from predict_api import ImagePredictor
|
||||
|
||||
# Initialize camera and face detector
|
||||
cap = cv2.VideoCapture(0)
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
|
||||
# Initialize ImagePredictor
|
||||
predictor = ImagePredictor(model_path="best_model_'0.8998410174880763'.pth", class_indices_path="./class_indices.json")
|
||||
|
||||
while True:
|
||||
# Capture frame-by-frame
|
||||
ret, frame = cap.read()
|
||||
|
||||
# Convert the image from BGR color (which OpenCV uses) to RGB color
|
||||
rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Perform face detection
|
||||
faces = detector(rgb_image)
|
||||
|
||||
# Loop through each face in this frame
|
||||
for rect in faces:
|
||||
# Get the bounding box coordinates
|
||||
x1, y1, x2, y2 = rect.left(), rect.top(), rect.right(), rect.bottom()
|
||||
|
||||
# Crop the face from the frame
|
||||
face_image = rgb_image[y1:y2, x1:x2]
|
||||
|
||||
# Use ImagePredictor to predict the class of this face
|
||||
result = predictor.predict(face_image)
|
||||
|
||||
# Draw a rectangle around the face
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Display the class name and score
|
||||
cv2.putText(frame, f"{result['result'][0]['name']}: {round(result['result'][0]['score'],4)}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
|
||||
|
||||
# Display the resulting frame
|
||||
cv2.imshow('Video', frame)
|
||||
|
||||
# Exit loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
# When everything is done, release the capture
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
|
@ -0,0 +1,112 @@
|
|||
import queue
|
||||
|
||||
import numpy as np
|
||||
from PyQt5.QtCore import QObject, pyqtSignal
|
||||
from multiprocessing import Process, Queue
|
||||
from apis.age.AgeGenderPredictor import AgeGenderPredictor
|
||||
from apis.hr.HeartRateMonitor import HeartRateMonitor
|
||||
from apis.st.predict_api import SkinTypePredictor
|
||||
from apis.sd.predict_api import SkinDiseasePredictor
|
||||
from apis.emotion.predict_api import EmotionPredictor
|
||||
from apis.bp.BPApi import BPModel
|
||||
from apis.rr.RespirationRateDetector import RespirationRateDetector
|
||||
from apis.rr.params import args
|
||||
|
||||
|
||||
def run_ai(input_queue, output_queue, fps_queue):
|
||||
age_predictor = AgeGenderPredictor('weights/age.pth')
|
||||
st_predictor = SkinTypePredictor('weights/st.pth', class_indices_path="labels/st.json")
|
||||
sd_predictor = SkinDiseasePredictor('weights/sd.pth', class_indices_path="labels/sd.json")
|
||||
emotion_predictor = EmotionPredictor('weights/emotion.pth', class_indices_path="labels/emotion.json")
|
||||
bp_predictor = BPModel(model_path=r'weights/bp.pth', fps=30)
|
||||
rr_detector = RespirationRateDetector(args)
|
||||
hr_detector = HeartRateMonitor(30, 0.8, 1.8)
|
||||
|
||||
while True:
|
||||
|
||||
frames = input_queue.get()
|
||||
fps = fps_queue.get()
|
||||
|
||||
if frames is None:
|
||||
break
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
else:
|
||||
# print(fps)
|
||||
bp_predictor.fps = fps
|
||||
hr_detector.fps = fps
|
||||
|
||||
results = {}
|
||||
|
||||
sbp_outputs, dbp_outputs = bp_predictor.predict(frames[-300:])
|
||||
|
||||
gender, age, age_group = age_predictor.predict(frames[-1])
|
||||
st = st_predictor.predict(frames[-1])
|
||||
sd = sd_predictor.predict(frames[-1])
|
||||
emotion = emotion_predictor.predict(frames[-1])
|
||||
|
||||
Resp, RR_FFT, RR_PC, RR_CP, RR_NFCP = rr_detector.detect_respiration_rate(frames[-300:], fps)
|
||||
heartrate, sdnn, rmssd, cv_rr, bvp = hr_detector.process_roi(frames[-300:])
|
||||
|
||||
RR = int(round((RR_FFT + RR_PC + RR_CP + RR_NFCP) / 4))
|
||||
|
||||
results['gender'] = gender
|
||||
results['age'] = int(age)
|
||||
results['age_group'] = age_group
|
||||
results['skin_type'] = st
|
||||
results['skin_disease'] = sd
|
||||
results['emotion'] = emotion
|
||||
results['sbp'] = int(round(sbp_outputs, 0))
|
||||
results['dbp'] = int(round(dbp_outputs, 0))
|
||||
results['rr'] = RR
|
||||
results['hr'] = int(heartrate)
|
||||
results['sdnn'] = round(sdnn, 1)
|
||||
|
||||
# 判断rmssd是否为nan
|
||||
results['rmssd'] = round(rmssd, 1) if not np.isnan(rmssd) else 0.0
|
||||
results['cvrr'] = round(cv_rr, 1)
|
||||
results['resp'] = Resp
|
||||
results['bvp'] = bvp
|
||||
|
||||
# 添加其他生理指标的计算...
|
||||
|
||||
def clear_queue(q):
|
||||
while not q.empty():
|
||||
try:
|
||||
q.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
clear_queue(input_queue)
|
||||
output_queue.put(results)
|
||||
|
||||
|
||||
class AIProcess(QObject):
|
||||
results_ready = pyqtSignal(dict)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
self.fps_queue = Queue()
|
||||
|
||||
self.process = Process(target=run_ai, args=(self.input_queue, self.output_queue, self.fps_queue))
|
||||
|
||||
def start(self):
|
||||
self.process.start()
|
||||
|
||||
def process_frames(self, frames):
|
||||
self.input_queue.put(frames)
|
||||
|
||||
def update_fps(self, fps):
|
||||
self.fps_queue.put(fps)
|
||||
|
||||
def stop(self):
|
||||
self.input_queue.put(None)
|
||||
self.process.join()
|
||||
|
||||
def get_results(self):
|
||||
if not self.output_queue.empty():
|
||||
return self.output_queue.get()
|
||||
return None
|
|
@ -0,0 +1,47 @@
|
|||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
import core.face_detection as ftm
|
||||
from numpy import ndarray
|
||||
import cv2
|
||||
|
||||
|
||||
class CameraThread(QThread):
|
||||
new_frame = pyqtSignal(ndarray)
|
||||
storage_frame = pyqtSignal(ndarray)
|
||||
fps_signal = pyqtSignal(float)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cap = cv2.VideoCapture(0)
|
||||
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
|
||||
|
||||
self.face_detector = ftm.FaceDetector()
|
||||
self.running = True
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
|
||||
ret, frame = self.cap.read()
|
||||
if ret:
|
||||
rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
img, bboxs = self.face_detector.find_faces(rgb_image)
|
||||
|
||||
face_img = None
|
||||
max_roi = 0
|
||||
for bbox in bboxs:
|
||||
try:
|
||||
x, y, w, h = bbox[1]
|
||||
if w * h > max_roi:
|
||||
max_roi = w * h
|
||||
face_img = img[y:y + h, x:x + w]
|
||||
except:
|
||||
pass
|
||||
|
||||
self.new_frame.emit(img)
|
||||
if face_img is not None:
|
||||
self.fps_signal.emit(self.fps)
|
||||
self.storage_frame.emit(cv2.resize(face_img, (224, 224)))
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.wait()
|
||||
self.cap.release()
|
|
@ -0,0 +1,58 @@
|
|||
import cv2
|
||||
|
||||
class FaceDetector:
|
||||
def __init__(self, cascade_path='haarcascade_frontalface_default.xml'):
|
||||
self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_path)
|
||||
|
||||
def find_faces(self, img, draw=True):
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
|
||||
|
||||
bboxs = []
|
||||
for i, (x, y, w, h) in enumerate(faces):
|
||||
bbox = (x, y, w, h)
|
||||
bboxs.append([i, bbox, None]) # 'None' 用于保持与原代码结构一致,因为OpenCV不提供置信度分数
|
||||
|
||||
if draw:
|
||||
img = self.fancy_draw(img, bbox)
|
||||
|
||||
return img, bboxs
|
||||
|
||||
def fancy_draw(self, img, bbox, l=30, t=5, rt=1):
|
||||
x, y, w, h = bbox
|
||||
x1, y1 = x + w, y + h
|
||||
|
||||
# 绘制矩形
|
||||
cv2.rectangle(img, bbox, (0, 255, 255), rt)
|
||||
|
||||
# 绘制角落
|
||||
# 左上角
|
||||
cv2.line(img, (x, y), (x+l, y), (255, 0, 255), t)
|
||||
cv2.line(img, (x, y), (x, y+l), (255, 0, 255), t)
|
||||
# 右上角
|
||||
cv2.line(img, (x1, y), (x1-l, y), (255, 0, 255), t)
|
||||
cv2.line(img, (x1, y), (x1, y+l), (255, 0, 255), t)
|
||||
# 左下角
|
||||
cv2.line(img, (x, y1), (x+l, y1), (255, 0, 255), t)
|
||||
cv2.line(img, (x, y1), (x, y1-l), (255, 0, 255), t)
|
||||
# 右下角
|
||||
cv2.line(img, (x1, y1), (x1-l, y1), (255, 0, 255), t)
|
||||
cv2.line(img, (x1, y1), (x1, y1-l), (255, 0, 255), t)
|
||||
|
||||
return img
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
detector = FaceDetector()
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
|
||||
while True:
|
||||
success, img = cap.read()
|
||||
img, bboxs = detector.find_faces(img)
|
||||
|
||||
cv2.imshow("Image", img)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
|
@ -0,0 +1,28 @@
|
|||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
|
||||
|
||||
class StorageThread(QThread):
|
||||
frames_ready = pyqtSignal(np.ndarray)
|
||||
|
||||
def __init__(self, max_frames=300):
|
||||
super().__init__()
|
||||
self.frame_buffer = deque(maxlen=max_frames)
|
||||
self.running = True
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
if len(self.frame_buffer) >= self.frame_buffer.maxlen:
|
||||
frames = np.array(list(self.frame_buffer))
|
||||
self.frames_ready.emit(frames)
|
||||
self.frame_buffer.clear()
|
||||
# print('Frames ready:', frames.shape)
|
||||
self.msleep(100) # 休眠100ms
|
||||
|
||||
def store_frame(self, frame):
|
||||
self.frame_buffer.append(frame)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.wait()
|
|
@ -0,0 +1,25 @@
|
|||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
|
||||
class UpdateThread(QThread):
|
||||
update_ui_signal = pyqtSignal(dict)
|
||||
|
||||
def __init__(self, ai_process):
|
||||
super().__init__()
|
||||
self.ai_process = ai_process
|
||||
self.running = True
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
results = self.ai_process.get_results()
|
||||
if results:
|
||||
self.update_ui_signal.emit(results)
|
||||
self.msleep(100) # 休眠100ms
|
||||
|
||||
def update_ui(self, results):
|
||||
print(results)
|
||||
self.update_ui_signal.emit(results)
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.wait()
|
After Width: | Height: | Size: 4.6 KiB |
After Width: | Height: | Size: 5.7 KiB |
After Width: | Height: | Size: 4.8 KiB |
After Width: | Height: | Size: 4.7 KiB |
After Width: | Height: | Size: 4.7 KiB |
After Width: | Height: | Size: 4.8 KiB |
After Width: | Height: | Size: 4.9 KiB |
After Width: | Height: | Size: 4.8 KiB |
After Width: | Height: | Size: 5.6 KiB |
|
@ -0,0 +1 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?><svg width="48" height="48" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M5.00372 42.2311C5.00372 42.6557 5.35807 42.9999 5.79521 42.9999L42.2023 43C42.6394 43 42.9938 42.6558 42.9938 42.2313V41.3131C43.012 41.0364 43.049 39.6555 42.1388 38.1289C41.5648 37.1663 40.7318 36.3347 39.6628 35.6573C38.3696 34.8378 36.7245 34.244 34.7347 33.8865C34.72 33.8846 33.2446 33.689 31.7331 33.303C29.101 32.6307 28.8709 32.0357 28.8694 32.0299C28.8539 31.9711 28.8315 31.9146 28.8028 31.8615C28.7813 31.7505 28.7281 31.3328 28.8298 30.2136C29.088 27.371 30.6128 25.691 31.838 24.3412C32.2244 23.9155 32.5893 23.5134 32.8704 23.1191C34.0827 21.4181 34.1952 19.4839 34.2003 19.364C34.2003 19.1211 34.1724 18.9214 34.1127 18.7363C33.9937 18.3659 33.7698 18.1351 33.6063 17.9666L33.6052 17.9654C33.564 17.923 33.5251 17.8828 33.4933 17.8459C33.4812 17.8318 33.449 17.7945 33.4783 17.603C33.5859 16.8981 33.6505 16.3079 33.6815 15.7456C33.7367 14.7438 33.7798 13.2456 33.5214 11.7875C33.4895 11.5385 33.4347 11.2755 33.3494 10.9622C33.0764 9.95814 32.6378 9.09971 32.0284 8.39124C31.9236 8.27722 29.3756 5.5928 21.9788 5.04201C20.956 4.96586 19.9449 5.00688 18.9496 5.05775C18.7097 5.06961 18.3812 5.08589 18.0738 5.16554C17.3101 5.36337 17.1063 5.84743 17.0528 6.11834C16.9641 6.56708 17.12 6.91615 17.2231 7.14718L17.2231 7.1472L17.2231 7.14723C17.2381 7.18072 17.2566 7.22213 17.2243 7.32997C17.0526 7.59588 16.7825 7.83561 16.5071 8.06273C16.4275 8.13038 14.5727 9.72968 14.4707 11.8189C14.1957 13.4078 14.2165 15.8834 14.5417 17.5944C14.5606 17.6889 14.5885 17.8288 14.5432 17.9233L14.5432 17.9233C14.1935 18.2367 13.7971 18.5919 13.7981 19.4024C13.8023 19.4839 13.9148 21.4181 15.1272 23.1191C15.408 23.5131 15.7726 23.9149 16.1587 24.3403L16.1596 24.3412L16.1596 24.3413C17.3848 25.6911 18.9095 27.371 19.1678 30.2135C19.2694 31.3327 19.2162 31.7505 19.1947 31.8614C19.166 31.9145 19.1436 31.971 19.1282 32.0298C19.1266 32.0356 18.8974 32.6287 16.2772 33.2996C14.7656 33.6867 13.2775 33.8845 13.2331 33.8909C11.2994 34.2173 9.66438 34.7963 8.37351 35.6115C7.30813 36.2844 6.47354 37.1175 5.89289 38.0877C4.96517 39.6379 4.99025 41.0497 5.00372 41.3074V42.2311Z" fill="none" stroke="#ffffff" stroke-width="4" stroke-linejoin="round"/></svg>
|
After Width: | Height: | Size: 2.2 KiB |
|
@ -0,0 +1 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?><svg width="48" height="48" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg"><path fill-rule="evenodd" clip-rule="evenodd" d="M24 43.6C32.432 43.6 39.5606 36.9192 41.8935 29.2497C42.4179 27.5255 46 27.5255 46 23.8C46 20.0745 42.3839 19.8601 41.7987 18.048C39.3724 10.5346 32.3209 4 24 4C15.6745 4 8.61973 10.5407 6.19725 18.0606C5.61467 19.8691 2 20.0091 2 23.8C2 27.5909 5.59225 27.5909 6.1349 29.3421C8.4967 36.9639 15.6018 43.6 24 43.6Z" fill="none" stroke="#ffffff" stroke-width="4"/><path d="M41.7987 18.048C39.3724 10.5346 32.3209 4 24 4" stroke="#ffffff" stroke-width="4" stroke-linecap="round"/><path d="M19.1002 21.5998C19.1002 22.4261 18.876 23.1516 18.5398 23.6559C18.2013 24.1637 17.7885 24.3998 17.4002 24.3998C17.0119 24.3998 16.5991 24.1637 16.2606 23.6559C15.9244 23.1516 15.7002 22.4261 15.7002 21.5998C15.7002 20.7735 15.9244 20.048 16.2606 19.5437C16.5991 19.0359 17.0119 18.7998 17.4002 18.7998C17.7885 18.7998 18.2013 19.0359 18.5398 19.5437C18.876 20.048 19.1002 20.7735 19.1002 21.5998Z" fill="#ffffff" stroke="#ffffff"/><path d="M32.2999 21.5998C32.2999 22.4261 32.0757 23.1516 31.7395 23.6559C31.401 24.1637 30.9882 24.3998 30.5999 24.3998C30.2116 24.3998 29.7988 24.1637 29.4603 23.6559C29.1241 23.1516 28.8999 22.4261 28.8999 21.5998C28.8999 20.7735 29.1241 20.048 29.4603 19.5437C29.7988 19.0359 30.2116 18.7998 30.5999 18.7998C30.9882 18.7998 31.401 19.0359 31.7395 19.5437C32.0757 20.048 32.2999 20.7735 32.2999 21.5998Z" fill="#ffffff" stroke="#ffffff"/><path fill-rule="evenodd" clip-rule="evenodd" d="M18.498 31.7505C20.4289 33.0501 22.266 33.6999 24.0094 33.6999C25.7509 33.6999 27.4776 33.0515 29.1894 31.7547" fill="#ffffff"/><path d="M18.498 31.7505C20.4289 33.0501 22.266 33.6999 24.0094 33.6999C25.7509 33.6999 27.4776 33.0515 29.1894 31.7547" stroke="#ffffff" stroke-width="4" stroke-linecap="round"/><path d="M31.7283 6.2002C31.9964 8.13368 31.4067 9.54651 29.9593 10.4387C28.5119 11.3309 26.1602 11.749 22.9043 11.693" stroke="#ffffff" stroke-width="4" stroke-linecap="round"/></svg>
|
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 3.8 KiB |
After Width: | Height: | Size: 3.6 KiB |
After Width: | Height: | Size: 5.9 KiB |
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"0": "ANGRY",
|
||||
"1": "CONFUSED",
|
||||
"2": "DISGUST",
|
||||
"3": "FEAR",
|
||||
"4": "HAPPY",
|
||||
"5": "NATURAL",
|
||||
"6": "SAD",
|
||||
"7": "SHY",
|
||||
"8": "SURPRISED"
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"0": "痤疮或酒渣鼻",
|
||||
"1": "光化性角化病基底细胞癌或其他恶性病变",
|
||||
"2": "过敏性皮炎",
|
||||
"3": "大疱性疾病",
|
||||
"4": "蜂窝织炎、脓疱病或其他细菌感染",
|
||||
"5": "湿疹",
|
||||
"6": "皮疹或药疹",
|
||||
"7": "脱发或其他头发疾病",
|
||||
"8": "健康",
|
||||
"9": "疱疹、HPV或其他性病",
|
||||
"10": "轻度疾病和色素沉着障碍",
|
||||
"11": "狼疮或其他结缔组织疾病",
|
||||
"12": "黑色素瘤皮肤癌痣或痣",
|
||||
"13": "指甲真菌或其他指甲疾病",
|
||||
"14": "毒藤或其他接触性皮炎",
|
||||
"15": "牛皮癣、扁平苔藓或相关疾病",
|
||||
"16": "疥疮、莱姆病或其他感染和叮咬",
|
||||
"17": "脂溢性角化病或其他良性肿瘤",
|
||||
"18": "全身性疾病",
|
||||
"19": "癣念珠菌病或其他真菌感染",
|
||||
"20": "荨麻疹",
|
||||
"21": "血管肿瘤",
|
||||
"22": "血管炎",
|
||||
"23": "疣、软疣或其他病毒感染"
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"0": "干性皮肤",
|
||||
"1": "正常皮肤",
|
||||
"2": "油性皮肤"
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
import sys
|
||||
from PyQt5.QtWidgets import QApplication
|
||||
from ui.main_window import MainWindow
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = QApplication(sys.argv)
|
||||
main_window = MainWindow()
|
||||
main_window.show()
|
||||
sys.exit(app.exec_())
|
|
@ -0,0 +1,87 @@
|
|||
# 健康监测系统
|
||||
|
||||
## 项目概述
|
||||
|
||||
这是一个基于计算机视觉和人工智能的健康监测系统。该系统利用摄像头捕获用户图像,并通过多个AI模型分析各种健康指标,如年龄、性别、血压、情绪状态、心率、呼吸率以及皮肤状况。
|
||||
|
||||
## 系统架构
|
||||
|
||||
```
|
||||
project/
|
||||
│
|
||||
├── main.py # 主程序入口
|
||||
├── ui/
|
||||
│ ├── main_window.py # 主窗口UI逻辑代码
|
||||
│ └── ui.py # 主窗口UI布局代码
|
||||
├── images/ # UI相关图片
|
||||
├── labels/ # 模型识别分类标签
|
||||
├── weights/ # 模型权重
|
||||
├── core/
|
||||
│ ├── camera_thread.py # 摄像头读取线程
|
||||
│ ├── storage_thread.py # 帧序列存储线程
|
||||
│ ├── update_thread.py # 结果更新线程
|
||||
│ └── api_process.py # AI处理进程
|
||||
└── apis/
|
||||
├── age/ # 年龄性别预测模型
|
||||
├── bp/ # 血压预测模型
|
||||
├── emotion/ # 情绪检测模型
|
||||
├── hr/ # 心率检测模型
|
||||
├── rr/ # 呼吸检测模型
|
||||
├── sd/ # 皮肤疾病检测模型
|
||||
└── st/ # 皮肤类型检测模型
|
||||
```
|
||||
|
||||
## 系统流程
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[主线程: GUI] --> B[线程1: 摄像头读取]
|
||||
A --> C[线程2: 帧序列存储]
|
||||
A --> D[进程1: AI处理]
|
||||
A --> E[线程3: 结果更新]
|
||||
B -->|帧数据| A
|
||||
B -->|帧数据| C
|
||||
C -->|帧序列| D
|
||||
D -->|处理结果| E
|
||||
E -->|更新UI| A
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 主要组件
|
||||
|
||||
1. **主线程(GUI线程)**
|
||||
- 运行PyQt5的事件循环和处理用户交互
|
||||
- 显示摄像头画面和AI处理结果
|
||||
- 协调其他线程和进程的工作
|
||||
|
||||
2. **摄像头读取线程**
|
||||
- 专门负责从摄像头读取画面
|
||||
- 使用QTimer定时触发画面捕获
|
||||
- 通过信号机制将捕获的帧传递给主线程显示和存储线程
|
||||
|
||||
3. **帧序列存储线程**
|
||||
- 接收来自摄像头读取线程的帧
|
||||
- 将帧序列存储到磁盘或内存缓冲区
|
||||
- 实现循环缓冲,只保留最近的N帧
|
||||
|
||||
4. **AI处理进程**
|
||||
- 使用多进程以充分利用多核CPU
|
||||
- 定期从存储的帧序列中获取数据
|
||||
- 运行多个AI算法来计算各项生理指标
|
||||
- 将处理结果通过进程间通信(如Queue)发送回主进程
|
||||
|
||||
5. **结果更新线程**
|
||||
- 接收来自AI处理进程的结果
|
||||
- 通过信号机制将结果传递给主线程进行UI更新
|
||||
|
||||
## 功能模块
|
||||
|
||||
- 年龄和性别预测
|
||||
- 血压预测
|
||||
- 情绪检测
|
||||
- 心率检测
|
||||
- 呼吸率检测
|
||||
- 皮肤疾病检测
|
||||
- 皮肤类型检测
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
import os
|
||||
|
||||
from PyQt5.QtWidgets import QMainWindow
|
||||
from PyQt5.QtCore import pyqtSlot, QTimer, QDateTime
|
||||
from PyQt5 import QtGui
|
||||
|
||||
from core.camera_thread import CameraThread
|
||||
from core.storage_thread import StorageThread
|
||||
from core.api_process import AIProcess
|
||||
from core.update_thread import UpdateThread
|
||||
from ui.ui import Ui_MainWindow
|
||||
|
||||
|
||||
class MainWindow(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ui = Ui_MainWindow()
|
||||
self.ui.setupUi(self)
|
||||
|
||||
self.images_path = "images"
|
||||
|
||||
self.pic_dict = {
|
||||
"Male": "male.png",
|
||||
"Female": "female.png",
|
||||
"ANGRY": "ANGRY.png",
|
||||
"DISGUST": "DISGUST.png",
|
||||
"FEAR": "FEAR.png",
|
||||
"HAPPY": "HAPPY.png",
|
||||
"NATURAL": "NATURAL.png",
|
||||
"SAD": "SAD.png",
|
||||
"SURPRISED": "SURPRISED.png",
|
||||
"CONFUSED": "CONFUSED.png",
|
||||
"SHY": "SHY.png",
|
||||
}
|
||||
|
||||
for key in self.pic_dict.keys():
|
||||
self.pic_dict[key] = os.path.join(self.images_path, self.pic_dict[key])
|
||||
|
||||
self.camera_thread = CameraThread()
|
||||
self.storage_thread = StorageThread()
|
||||
self.ai_process = AIProcess()
|
||||
self.update_thread = UpdateThread(self.ai_process)
|
||||
|
||||
self.camera_thread.new_frame.connect(self.ui.update_image)
|
||||
self.camera_thread.storage_frame.connect(self.storage_thread.store_frame)
|
||||
|
||||
self.camera_thread.fps_signal.connect(self.ai_process.update_fps)
|
||||
self.storage_thread.frames_ready.connect(self.ai_process.process_frames)
|
||||
self.ai_process.results_ready.connect(self.update_thread.update_ui)
|
||||
self.update_thread.update_ui_signal.connect(self.update_ui)
|
||||
|
||||
self.camera_thread.start()
|
||||
self.storage_thread.start()
|
||||
self.ai_process.start()
|
||||
self.update_thread.start()
|
||||
|
||||
# 设置定时器更新时间显示
|
||||
self.timer = QTimer(self)
|
||||
self.timer.timeout.connect(self.update_time)
|
||||
self.timer.start(1000) # 每秒更新一次
|
||||
|
||||
@pyqtSlot(dict)
|
||||
def update_ui(self, results):
|
||||
# print(results)
|
||||
# print("update ui")
|
||||
Resp = results.get('resp', [])
|
||||
if Resp is not None:
|
||||
Resp[0] = self.ui.last_resp
|
||||
|
||||
for res in Resp:
|
||||
self.ui.resps.put(res)
|
||||
self.ui.last_resp = Resp[-1]
|
||||
|
||||
Bvp = results.get('bvp', [])
|
||||
if Bvp is not None:
|
||||
Bvp[0] = self.ui.last_bvp
|
||||
|
||||
for bvp in Bvp:
|
||||
self.ui.bvps.put(bvp)
|
||||
self.ui.last_bvp = Bvp[-1]
|
||||
|
||||
# 更新UI上的各项指标
|
||||
self.ui.sbp_value.setText(str(results.get('sbp', '0')))
|
||||
self.ui.dbp_value.setText(str(results.get('dbp', '0')))
|
||||
self.ui.hr_value.setText(str(results.get('hr', '0')))
|
||||
self.ui.cvrr_value.setText(str(results.get('cvrr', '0')))
|
||||
self.ui.sdnn_value.setText(str(results.get('sdnn', '0')))
|
||||
self.ui.rmssd_value.setText(str(results.get('rmssd', '0')))
|
||||
self.ui.emotion_value.setText(results.get('emotion', '0'))
|
||||
self.ui.rr_value.setText(str(results.get('rr', '0')))
|
||||
self.ui.spo2_value.setText(str(results.get('spo2', '0')))
|
||||
|
||||
if results.get('skin_disease', '健康') != '健康':
|
||||
self.ui.skin_diseance_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
else:
|
||||
self.ui.skin_diseance_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:28pt;")
|
||||
self.ui.skin_diseance_value.setText(results.get('skin_disease', '健康'))
|
||||
|
||||
self.ui.skin_ctype_value.setText(results.get('skin_type', '正常皮肤'))
|
||||
self.ui.age_value.setProperty("value", results.get('age', 30))
|
||||
self.ui.sex_value.setPixmap(QtGui.QPixmap(self.pic_dict.get(results.get('gender', 'Male'))))
|
||||
|
||||
self.ui.emotion_value.setText(results.get('emotion', 'NATURAL'))
|
||||
self.ui.emotion_icon.setPixmap(QtGui.QPixmap(self.pic_dict.get(results.get('emotion', 'NATURAL'))))
|
||||
|
||||
def update_time(self):
|
||||
current_time = QDateTime.currentDateTime()
|
||||
time_display = current_time.toString('yyyy年MM月dd日 hh:mm:ss dddd')
|
||||
self.ui.top.setText(time_display)
|
||||
|
||||
def closeEvent(self, event):
|
||||
self.camera_thread.stop()
|
||||
self.storage_thread.stop()
|
||||
self.ai_process.stop()
|
||||
self.update_thread.stop()
|
||||
super().closeEvent(event)
|
|
@ -0,0 +1,666 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Form implementation generated from reading ui file 'demo.ui'
|
||||
#
|
||||
# Created by: PyQt5 UI code generator 5.15.9
|
||||
#
|
||||
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
|
||||
# run again. Do not edit this file unless you know what you are doing.
|
||||
|
||||
import queue
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from PyQt5.QtCore import QTimer, Qt
|
||||
from PyQt5.QtGui import QImage, QPixmap
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FC
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Ui_MainWindow(object):
|
||||
|
||||
def mousePressEvent(self, event):
|
||||
if event.button() == Qt.LeftButton:
|
||||
self.m_drag = True
|
||||
self.m_DragPosition = event.globalPos() - self.window.pos()
|
||||
event.accept()
|
||||
|
||||
def mouseMoveEvent(self, QMouseEvent):
|
||||
if Qt.LeftButton and self.m_drag:
|
||||
self.window.move(QMouseEvent.globalPos() - self.m_DragPosition)
|
||||
QMouseEvent.accept()
|
||||
|
||||
def mouseReleaseEvent(self, QMouseEvent):
|
||||
self.m_drag = False
|
||||
|
||||
def setupUi(self, MainWindow):
|
||||
MainWindow.setObjectName("MainWindow")
|
||||
MainWindow.resize(1920, 1080)
|
||||
MainWindow.setMinimumSize(QtCore.QSize(1920, 1080))
|
||||
MainWindow.setMaximumSize(QtCore.QSize(1920, 1080))
|
||||
# 设置窗体无边框
|
||||
MainWindow.setWindowFlags(Qt.Window | Qt.FramelessWindowHint)
|
||||
MainWindow.setAttribute(Qt.WA_TranslucentBackground)
|
||||
|
||||
self.window = MainWindow
|
||||
MainWindow.setMouseTracking(True)
|
||||
MainWindow.mousePressEvent = self.mousePressEvent
|
||||
MainWindow.mouseMoveEvent = self.mouseMoveEvent
|
||||
MainWindow.mouseReleaseEvent = self.mouseReleaseEvent
|
||||
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(16)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
MainWindow.setFont(font)
|
||||
MainWindow.setStyleSheet("color:rgb(255,255,255); \n"
|
||||
"font: 16pt \"Microsoft YaHei\";\n"
|
||||
"font-weight:600;\n"
|
||||
"background:rgb(13, 29, 66)\n"
|
||||
"")
|
||||
self.centralwidget = QtWidgets.QWidget(MainWindow)
|
||||
self.centralwidget.setObjectName("centralwidget")
|
||||
self.top = QtWidgets.QLabel(self.centralwidget)
|
||||
self.top.setGeometry(QtCore.QRect(5, 5, 1910, 60))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(18)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
self.top.setFont(font)
|
||||
self.top.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"padding-right:20;font-size:18pt;\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.top.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.top.setScaledContents(False)
|
||||
self.top.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignTrailing | QtCore.Qt.AlignVCenter)
|
||||
self.top.setWordWrap(False)
|
||||
self.top.setObjectName("top")
|
||||
self.bvp = QtWidgets.QLabel(self.centralwidget)
|
||||
self.bvp.setGeometry(QtCore.QRect(5, 809, 510, 261))
|
||||
self.bvp.setStyleSheet("background:rgba(13, 29, 66, 0);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.bvp.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.bvp.setMidLineWidth(0)
|
||||
self.bvp.setText("")
|
||||
self.bvp.setObjectName("bvp")
|
||||
self.emotion = QtWidgets.QLabel(self.centralwidget)
|
||||
self.emotion.setGeometry(QtCore.QRect(1405, 70, 510, 200))
|
||||
self.emotion.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.emotion.setFrameShape(QtWidgets.QFrame.Box)
|
||||
self.emotion.setMidLineWidth(0)
|
||||
self.emotion.setText("")
|
||||
self.emotion.setObjectName("emotion")
|
||||
self.skin = QtWidgets.QLabel(self.centralwidget)
|
||||
self.skin.setGeometry(QtCore.QRect(1405, 280, 510, 230))
|
||||
self.skin.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.skin.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.skin.setMidLineWidth(0)
|
||||
self.skin.setText("")
|
||||
self.skin.setObjectName("skin")
|
||||
self.resp = QtWidgets.QLabel(self.centralwidget)
|
||||
self.resp.setGeometry(QtCore.QRect(1405, 709, 510, 361))
|
||||
self.resp.setStyleSheet("background:rgba(13, 29, 66 ,0);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.resp.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.resp.setMidLineWidth(0)
|
||||
self.resp.setText("")
|
||||
self.resp.setObjectName("resp")
|
||||
self.rr = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rr.setGeometry(QtCore.QRect(1405, 520, 510, 181))
|
||||
self.rr.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.rr.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.rr.setFrameShadow(QtWidgets.QFrame.Plain)
|
||||
self.rr.setLineWidth(1)
|
||||
self.rr.setMidLineWidth(0)
|
||||
self.rr.setText("")
|
||||
self.rr.setObjectName("rr")
|
||||
self.age_region = QtWidgets.QLabel(self.centralwidget)
|
||||
self.age_region.setGeometry(QtCore.QRect(5, 70, 510, 200))
|
||||
self.age_region.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.age_region.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.age_region.setMidLineWidth(0)
|
||||
self.age_region.setText("")
|
||||
self.age_region.setObjectName("age_region")
|
||||
self.hr = QtWidgets.QLabel(self.centralwidget)
|
||||
self.hr.setGeometry(QtCore.QRect(5, 500, 510, 301))
|
||||
self.hr.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.hr.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.hr.setMidLineWidth(0)
|
||||
self.hr.setText("")
|
||||
self.hr.setObjectName("hr")
|
||||
self.bp = QtWidgets.QLabel(self.centralwidget)
|
||||
self.bp.setGeometry(QtCore.QRect(5, 280, 510, 211))
|
||||
self.bp.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;")
|
||||
self.bp.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.bp.setMidLineWidth(0)
|
||||
self.bp.setText("")
|
||||
self.bp.setObjectName("bp")
|
||||
self.image = QtWidgets.QLabel(self.centralwidget)
|
||||
self.image.setGeometry(QtCore.QRect(520, 70, 881, 1001))
|
||||
self.image.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"border:2px solid grey;\n"
|
||||
"border-radius: 5px;margin:0;padding:0;")
|
||||
self.image.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.image.setText("")
|
||||
# self.image.setPixmap(QtGui.QPixmap("../../au/featDemo/1.png"))
|
||||
self.image.setObjectName("image")
|
||||
self.title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.title.setGeometry(QtCore.QRect(750, 5, 421, 60))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(28)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
self.title.setFont(font)
|
||||
self.title.setStyleSheet("font-size:28pt;\n"
|
||||
"background:none;")
|
||||
self.title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.title.setObjectName("title")
|
||||
self.sex_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sex_title.setGeometry(QtCore.QRect(50, 90, 71, 41))
|
||||
self.sex_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.sex_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sex_title.setObjectName("sex_title")
|
||||
self.sex_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sex_value.setGeometry(QtCore.QRect(50, 160, 71, 71))
|
||||
self.sex_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"")
|
||||
self.sex_value.setText("")
|
||||
self.sex_value.setPixmap(QtGui.QPixmap("images/female.png"))
|
||||
self.sex_value.setScaledContents(True)
|
||||
self.sex_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sex_value.setObjectName("sex_value")
|
||||
self.adult = QtWidgets.QLabel(self.centralwidget)
|
||||
self.adult.setGeometry(QtCore.QRect(310, 170, 51, 51))
|
||||
self.adult.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"")
|
||||
self.adult.setText("")
|
||||
self.adult.setPixmap(QtGui.QPixmap("images/adult.svg"))
|
||||
self.adult.setScaledContents(True)
|
||||
self.adult.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.adult.setObjectName("adult")
|
||||
self.age_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.age_title.setGeometry(QtCore.QRect(300, 90, 71, 41))
|
||||
self.age_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.age_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.age_title.setObjectName("age_title")
|
||||
self.age_value = QtWidgets.QProgressBar(self.centralwidget)
|
||||
self.age_value.setGeometry(QtCore.QRect(190, 229, 291, 5))
|
||||
self.age_value.setStyleSheet(".QProgressBar{\n"
|
||||
"background:rgb(255, 255, 255);\n"
|
||||
"border-radius:2%;\n"
|
||||
"border:1px solid grey;}\n"
|
||||
".QProgressBar::chunk {\n"
|
||||
" background-color:rgb(13, 29, 66)\n"
|
||||
"}\n"
|
||||
"")
|
||||
self.age_value.setMaximum(70)
|
||||
self.age_value.setProperty("value", 42)
|
||||
self.age_value.setTextVisible(False)
|
||||
self.age_value.setOrientation(QtCore.Qt.Horizontal)
|
||||
self.age_value.setObjectName("age_value")
|
||||
self.baby = QtWidgets.QLabel(self.centralwidget)
|
||||
self.baby.setGeometry(QtCore.QRect(170, 170, 51, 51))
|
||||
self.baby.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"")
|
||||
self.baby.setText("")
|
||||
self.baby.setPixmap(QtGui.QPixmap("images/baby.svg"))
|
||||
self.baby.setScaledContents(True)
|
||||
self.baby.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.baby.setObjectName("baby")
|
||||
self.old = QtWidgets.QLabel(self.centralwidget)
|
||||
self.old.setGeometry(QtCore.QRect(450, 170, 51, 51))
|
||||
self.old.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"")
|
||||
self.old.setText("")
|
||||
self.old.setPixmap(QtGui.QPixmap("images/old.png"))
|
||||
self.old.setScaledContents(True)
|
||||
self.old.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.old.setObjectName("old")
|
||||
self.logo = QtWidgets.QLabel(self.centralwidget)
|
||||
self.logo.setGeometry(QtCore.QRect(1310, 80, 81, 61))
|
||||
self.logo.setStyleSheet("background:none;")
|
||||
self.logo.setText("")
|
||||
self.logo.setPixmap(QtGui.QPixmap("../../../../../公共安全.png"))
|
||||
self.logo.setScaledContents(True)
|
||||
self.logo.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.logo.setObjectName("logo")
|
||||
self.sbp_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sbp_title.setGeometry(QtCore.QRect(105, 300, 71, 41))
|
||||
self.sbp_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.sbp_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sbp_title.setObjectName("sbp_title")
|
||||
self.dbp_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.dbp_title.setGeometry(QtCore.QRect(340, 300, 71, 41))
|
||||
self.dbp_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.dbp_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.dbp_title.setObjectName("dbp_title")
|
||||
self.sbp_type = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sbp_type.setGeometry(QtCore.QRect(70, 420, 141, 41))
|
||||
self.sbp_type.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:24pt;")
|
||||
self.sbp_type.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sbp_type.setObjectName("sbp_type")
|
||||
self.sbp_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sbp_value.setGeometry(QtCore.QRect(70, 360, 141, 51))
|
||||
self.sbp_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.sbp_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sbp_value.setObjectName("sbp_value")
|
||||
self.dbp_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.dbp_value.setGeometry(QtCore.QRect(305, 360, 141, 51))
|
||||
self.dbp_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.dbp_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.dbp_value.setObjectName("dbp_value")
|
||||
self.dbp_type = QtWidgets.QLabel(self.centralwidget)
|
||||
self.dbp_type.setGeometry(QtCore.QRect(305, 420, 141, 41))
|
||||
self.dbp_type.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:24pt;")
|
||||
self.dbp_type.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.dbp_type.setObjectName("dbp_type")
|
||||
self.hr_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.hr_title.setGeometry(QtCore.QRect(100, 520, 71, 41))
|
||||
self.hr_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.hr_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.hr_title.setObjectName("hr_title")
|
||||
self.hr_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.hr_value.setGeometry(QtCore.QRect(65, 560, 141, 51))
|
||||
self.hr_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.hr_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.hr_value.setObjectName("hr_value")
|
||||
self.hr_type = QtWidgets.QLabel(self.centralwidget)
|
||||
self.hr_type.setGeometry(QtCore.QRect(65, 610, 141, 41))
|
||||
self.hr_type.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:24pt;")
|
||||
self.hr_type.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.hr_type.setObjectName("hr_type")
|
||||
self.cvrr_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.cvrr_title.setGeometry(QtCore.QRect(330, 520, 91, 41))
|
||||
self.cvrr_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.cvrr_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.cvrr_title.setObjectName("cvrr_title")
|
||||
self.cvrr_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.cvrr_value.setGeometry(QtCore.QRect(305, 560, 141, 51))
|
||||
self.cvrr_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.cvrr_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.cvrr_value.setObjectName("cvrr_value")
|
||||
self.rmssd_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rmssd_title.setGeometry(QtCore.QRect(330, 680, 91, 41))
|
||||
self.rmssd_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.rmssd_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rmssd_title.setObjectName("rmssd_title")
|
||||
self.sdnn_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sdnn_title.setGeometry(QtCore.QRect(90, 680, 91, 41))
|
||||
self.sdnn_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.sdnn_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sdnn_title.setObjectName("sdnn_title")
|
||||
self.sdnn_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.sdnn_value.setGeometry(QtCore.QRect(65, 720, 141, 51))
|
||||
self.sdnn_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.sdnn_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.sdnn_value.setObjectName("sdnn_value")
|
||||
self.rmssd_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rmssd_value.setGeometry(QtCore.QRect(305, 720, 141, 51))
|
||||
self.rmssd_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.rmssd_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rmssd_value.setObjectName("rmssd_value")
|
||||
self.emotion_icon = QtWidgets.QLabel(self.centralwidget)
|
||||
self.emotion_icon.setGeometry(QtCore.QRect(1500, 150, 71, 71))
|
||||
self.emotion_icon.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"")
|
||||
self.emotion_icon.setText("")
|
||||
self.emotion_icon.setPixmap(QtGui.QPixmap("images/happy.png"))
|
||||
self.emotion_icon.setScaledContents(True)
|
||||
self.emotion_icon.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.emotion_icon.setObjectName("emotion_icon")
|
||||
self.emotion_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.emotion_value.setGeometry(QtCore.QRect(1610, 150, 231, 71))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(36)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
self.emotion_value.setFont(font)
|
||||
self.emotion_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:36pt;")
|
||||
self.emotion_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.emotion_value.setObjectName("emotion_value")
|
||||
self.emotion_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.emotion_title.setGeometry(QtCore.QRect(1610, 90, 100, 40))
|
||||
self.emotion_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.emotion_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.emotion_title.setObjectName("emotion_title")
|
||||
self.skin_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.skin_title.setGeometry(QtCore.QRect(1610, 300, 100, 40))
|
||||
self.skin_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.skin_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.skin_title.setObjectName("skin_title")
|
||||
self.skin_ctype_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.skin_ctype_value.setGeometry(QtCore.QRect(1555, 350, 211, 71))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(36)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
self.skin_ctype_value.setFont(font)
|
||||
self.skin_ctype_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:36pt;")
|
||||
self.skin_ctype_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.skin_ctype_value.setObjectName("skin_ctype_value")
|
||||
self.skin_diseance_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.skin_diseance_value.setGeometry(QtCore.QRect(1415, 420, 491, 71))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Microsoft YaHei")
|
||||
font.setPointSize(28)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
self.skin_diseance_value.setFont(font)
|
||||
self.skin_diseance_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:28pt;")
|
||||
self.skin_diseance_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.skin_diseance_value.setObjectName("skin_diseance_value")
|
||||
self.rr_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rr_title.setGeometry(QtCore.QRect(1495, 540, 111, 41))
|
||||
self.rr_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.rr_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rr_title.setObjectName("rr_title")
|
||||
self.rr_type = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rr_type.setGeometry(QtCore.QRect(1480, 630, 141, 41))
|
||||
self.rr_type.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:24pt;")
|
||||
self.rr_type.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rr_type.setObjectName("rr_type")
|
||||
self.rr_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rr_value.setGeometry(QtCore.QRect(1480, 580, 141, 51))
|
||||
self.rr_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.rr_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rr_value.setObjectName("rr_value")
|
||||
self.spo2_value = QtWidgets.QLabel(self.centralwidget)
|
||||
self.spo2_value.setGeometry(QtCore.QRect(1695, 580, 141, 51))
|
||||
self.spo2_value.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:40pt;")
|
||||
self.spo2_value.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.spo2_value.setObjectName("spo2_value")
|
||||
self.spo2_type = QtWidgets.QLabel(self.centralwidget)
|
||||
self.spo2_type.setGeometry(QtCore.QRect(1695, 630, 141, 41))
|
||||
self.spo2_type.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:24pt;")
|
||||
self.spo2_type.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.spo2_type.setObjectName("spo2_type")
|
||||
self.spo2_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.spo2_title.setGeometry(QtCore.QRect(1710, 540, 111, 41))
|
||||
self.spo2_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.spo2_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.spo2_title.setObjectName("spo2_title")
|
||||
self.rPPG_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.rPPG_title.setGeometry(QtCore.QRect(190, 820, 151, 41))
|
||||
self.rPPG_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.rPPG_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.rPPG_title.setObjectName("rPPG_title")
|
||||
self.resp_title = QtWidgets.QLabel(self.centralwidget)
|
||||
self.resp_title.setGeometry(QtCore.QRect(1620, 720, 100, 40))
|
||||
self.resp_title.setStyleSheet("background:rgb(13, 29, 66);\n"
|
||||
"font-size:18pt;")
|
||||
self.resp_title.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.resp_title.setObjectName("resp_title")
|
||||
self.horizontalLayoutWidget = QtWidgets.QWidget(self.centralwidget)
|
||||
self.horizontalLayoutWidget.setGeometry(QtCore.QRect(-50, 870, 601, 191))
|
||||
self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
|
||||
self.rppg_view = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
|
||||
self.rppg_view.setContentsMargins(0, 0, 0, 0)
|
||||
self.rppg_view.setObjectName("rppg_view")
|
||||
self.horizontalLayoutWidget_2 = QtWidgets.QWidget(self.centralwidget)
|
||||
self.horizontalLayoutWidget_2.setGeometry(QtCore.QRect(1350, 770, 611, 291))
|
||||
self.horizontalLayoutWidget_2.setObjectName("horizontalLayoutWidget_2")
|
||||
self.resp_view = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget_2)
|
||||
self.resp_view.setContentsMargins(0, 0, 0, 0)
|
||||
self.resp_view.setObjectName("resp_view")
|
||||
self.top.raise_()
|
||||
self.emotion.raise_()
|
||||
self.skin.raise_()
|
||||
self.rr.raise_()
|
||||
self.age_region.raise_()
|
||||
self.hr.raise_()
|
||||
self.bp.raise_()
|
||||
self.title.raise_()
|
||||
self.sex_title.raise_()
|
||||
self.sex_value.raise_()
|
||||
self.adult.raise_()
|
||||
self.age_title.raise_()
|
||||
self.age_value.raise_()
|
||||
self.baby.raise_()
|
||||
self.old.raise_()
|
||||
self.sbp_title.raise_()
|
||||
self.dbp_title.raise_()
|
||||
self.sbp_type.raise_()
|
||||
self.sbp_value.raise_()
|
||||
self.dbp_value.raise_()
|
||||
self.dbp_type.raise_()
|
||||
self.hr_title.raise_()
|
||||
self.hr_value.raise_()
|
||||
self.hr_type.raise_()
|
||||
self.cvrr_title.raise_()
|
||||
self.cvrr_value.raise_()
|
||||
self.rmssd_title.raise_()
|
||||
self.sdnn_title.raise_()
|
||||
self.sdnn_value.raise_()
|
||||
self.rmssd_value.raise_()
|
||||
self.emotion_icon.raise_()
|
||||
self.emotion_value.raise_()
|
||||
self.emotion_title.raise_()
|
||||
self.skin_title.raise_()
|
||||
self.skin_ctype_value.raise_()
|
||||
self.skin_diseance_value.raise_()
|
||||
self.rr_title.raise_()
|
||||
self.rr_type.raise_()
|
||||
self.rr_value.raise_()
|
||||
self.spo2_value.raise_()
|
||||
self.spo2_type.raise_()
|
||||
self.spo2_title.raise_()
|
||||
self.horizontalLayoutWidget.raise_()
|
||||
self.horizontalLayoutWidget_2.raise_()
|
||||
self.bvp.raise_()
|
||||
self.rPPG_title.raise_()
|
||||
self.resp.raise_()
|
||||
self.resp_title.raise_()
|
||||
self.image.raise_()
|
||||
self.logo.raise_()
|
||||
|
||||
# 解决无法显示中文
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
# 解决无法显示负号
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
# 设置画布部分
|
||||
self.rppg_fig = plt.figure(facecolor='#0d1d42', figsize=(5, 1))
|
||||
self.rppg_ax = self.rppg_fig.add_subplot(111)
|
||||
|
||||
self.resp_fig = plt.figure(facecolor='#0d1d42', figsize=(5, 1))
|
||||
self.resp_ax = self.resp_fig.add_subplot(111)
|
||||
|
||||
# 隐藏边框
|
||||
self.resp_ax.set_frame_on(False)
|
||||
self.rppg_ax.set_frame_on(False)
|
||||
|
||||
self.rppg_canvas = FC(self.rppg_fig)
|
||||
|
||||
self.resp_canvas = FC(self.resp_fig)
|
||||
|
||||
self.rppg_view.addWidget(self.rppg_canvas)
|
||||
self.resp_view.addWidget(self.resp_canvas)
|
||||
self.rppg_canvas.raise_()
|
||||
self.resp_canvas.raise_()
|
||||
|
||||
self.resps = queue.Queue()
|
||||
self.bvps = queue.Queue()
|
||||
|
||||
self.last_resp = 0
|
||||
self.last_bvp = 0
|
||||
|
||||
self.rr_x = []
|
||||
self.rr_y = []
|
||||
|
||||
self.bvp_x = []
|
||||
self.bvp_y = []
|
||||
#
|
||||
# # 设置定时器更新曲线图
|
||||
self.timer = QTimer(MainWindow)
|
||||
self.timer.timeout.connect(self.update_resp_plot)
|
||||
self.timer.start(100)
|
||||
|
||||
self.timer1 = QTimer(MainWindow)
|
||||
self.timer1.timeout.connect(self.update_bvp_plot)
|
||||
self.timer1.start(100)
|
||||
|
||||
MainWindow.setCentralWidget(self.centralwidget)
|
||||
|
||||
self.retranslateUi(MainWindow)
|
||||
QtCore.QMetaObject.connectSlotsByName(MainWindow)
|
||||
|
||||
def retranslateUi(self, MainWindow):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
|
||||
self.top.setText(_translate("MainWindow", "2024年06月11日 15:11:11 星期二"))
|
||||
self.title.setText(_translate("MainWindow", "体检仪"))
|
||||
self.sex_title.setText(_translate("MainWindow", "性别"))
|
||||
self.age_title.setText(_translate("MainWindow", "年龄"))
|
||||
self.sbp_title.setText(_translate("MainWindow", "收缩压"))
|
||||
self.dbp_title.setText(_translate("MainWindow", "舒张压"))
|
||||
self.sbp_type.setText(_translate("MainWindow", "mmHg"))
|
||||
self.sbp_value.setText(_translate("MainWindow", "140"))
|
||||
self.dbp_value.setText(_translate("MainWindow", "80"))
|
||||
self.dbp_type.setText(_translate("MainWindow", "mmHg"))
|
||||
self.hr_title.setText(_translate("MainWindow", "心率"))
|
||||
self.hr_value.setText(_translate("MainWindow", "60"))
|
||||
self.hr_type.setText(_translate("MainWindow", "bpm"))
|
||||
self.cvrr_title.setText(_translate("MainWindow", "CV R-R"))
|
||||
self.cvrr_value.setText(_translate("MainWindow", "1.1"))
|
||||
self.rmssd_title.setText(_translate("MainWindow", "RMSSD"))
|
||||
self.sdnn_title.setText(_translate("MainWindow", "SDNN"))
|
||||
self.sdnn_value.setText(_translate("MainWindow", "6.6"))
|
||||
self.rmssd_value.setText(_translate("MainWindow", "8.3"))
|
||||
self.emotion_value.setText(_translate("MainWindow", "HAPPY"))
|
||||
self.emotion_title.setText(_translate("MainWindow", "表情"))
|
||||
self.skin_title.setText(_translate("MainWindow", "皮肤属性"))
|
||||
self.skin_ctype_value.setText(_translate("MainWindow", "油性皮肤"))
|
||||
self.skin_diseance_value.setText(_translate("MainWindow", "健康"))
|
||||
self.rr_title.setText(_translate("MainWindow", "呼吸频率"))
|
||||
self.rr_type.setText(_translate("MainWindow", "bpm"))
|
||||
self.rr_value.setText(_translate("MainWindow", "20"))
|
||||
self.spo2_value.setText(_translate("MainWindow", "20"))
|
||||
self.spo2_type.setText(_translate("MainWindow", "%"))
|
||||
self.spo2_title.setText(_translate("MainWindow", "血氧"))
|
||||
self.rPPG_title.setText(_translate("MainWindow", "BVP波动"))
|
||||
self.resp_title.setText(_translate("MainWindow", "呼吸波动"))
|
||||
|
||||
def update_image(self, image):
|
||||
|
||||
height, width = image.shape[:2]
|
||||
|
||||
target_width = width * self.image.height() / height
|
||||
|
||||
pixmap = QImage(image, width, height, QImage.Format_RGB888)
|
||||
|
||||
pixmap = QPixmap.fromImage(pixmap).scaled(target_width, self.image.height())
|
||||
|
||||
self.image.setAlignment(Qt.AlignCenter)
|
||||
self.image.setPixmap(pixmap)
|
||||
|
||||
def update_resp_plot(self):
|
||||
|
||||
if not self.resps.empty():
|
||||
resp = self.resps.get()
|
||||
# 更新线的数据
|
||||
self.rr_y.append(resp)
|
||||
else:
|
||||
self.rr_y.append(0)
|
||||
|
||||
if len(self.rr_x) == 0:
|
||||
self.rr_x.append(1)
|
||||
else:
|
||||
self.rr_x.append(self.rr_x[-1] + 1)
|
||||
|
||||
if len(self.rr_x) > 300:
|
||||
self.rr_x.pop(0)
|
||||
self.rr_y.pop(0)
|
||||
|
||||
# 生成时间序列
|
||||
t = np.linspace(self.rr_x[0] / 30, self.rr_x[-1] / 30, len(self.rr_y))
|
||||
|
||||
self.resp_ax.clear()
|
||||
|
||||
self.resp_ax.set_ylim(max(-1, min(self.rr_y) - 0.5), max(2, max(self.rr_y) + 0.5))
|
||||
self.resp_ax.plot(t, self.rr_y, color='#33FFFF')
|
||||
|
||||
self.resp_ax.axis('off')
|
||||
|
||||
self.resp_canvas.draw()
|
||||
|
||||
def update_bvp_plot(self):
|
||||
|
||||
if not self.bvps.empty():
|
||||
bvp = self.bvps.get()
|
||||
# 更新线的数据
|
||||
self.bvp_y.append(bvp)
|
||||
else:
|
||||
self.bvp_y.append(0)
|
||||
|
||||
if len(self.bvp_x) == 0:
|
||||
self.bvp_x.append(1)
|
||||
else:
|
||||
self.bvp_x.append(self.bvp_x[-1] + 1)
|
||||
|
||||
if len(self.bvp_x) > 300:
|
||||
self.bvp_x.pop(0)
|
||||
self.bvp_y.pop(0)
|
||||
|
||||
# 生成时间序列
|
||||
t = np.linspace(self.bvp_x[0] / 30, self.bvp_x[-1] / 30, len(self.bvp_y))
|
||||
|
||||
self.rppg_ax.clear()
|
||||
|
||||
self.rppg_ax.set_ylim(min(-4, min(self.bvp_y) - 0.5), max(4, max(self.bvp_y) + 0.5))
|
||||
self.rppg_ax.plot(t, self.bvp_y, color='#FFFF00')
|
||||
|
||||
self.rppg_ax.axis('off')
|
||||
|
||||
self.rppg_canvas.draw()
|