# YOLOv3-model-pruning/testdict.py
#
# Repository viewer metadata (not part of the program):
#   15 lines, 502 B, Python; last modified 2024-06-25 14:07:50 +08:00
import torch
from models import *
class opt:
    """Static settings for this script: model definition, data config and checkpoint paths."""

    # Network definition file handed to Darknet(...) further down in this file.
    model_def = "config/ds_1w_prune_0.5_yolov3-ds-person.cfg"
    # Data config path (alternative noted by the author: person.data).
    data_config = "config/smallperson.data"
    # Checkpoint file handed to torch.load(...) further down in this file
    # (alternative noted by the author: checkpoints/yolov3_ckpt.pth).
    model = 'checkpoints/ds_1w_prune_0.5_yolov3_ckpt_99_05181112.pth'
# Load the model.
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the network described by opt.model_def and move it to the chosen device.
model = Darknet(opt.model_def).to(device)
# Load the weights. map_location remaps the stored tensors onto `device`, so a
# checkpoint saved from a CUDA run can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(opt.model, map_location=device))
# Print the model.
print(model)