15 lines
502 B
Python
15 lines
502 B
Python
|
import torch
|
||
|
from models import *
|
||
|
|
||
|
class opt:
    """Hard-coded settings for this script.

    Used as a plain namespace: values are read directly as class attributes
    (e.g. ``opt.model_def``); the class is never instantiated.
    """

    # Darknet .cfg file describing the network that gets built.
    model_def = "config/ds_1w_prune_0.5_yolov3-ds-person.cfg"

    # Dataset description file (alternative: "config/person.data").
    # NOTE(review): not referenced by the visible script.
    data_config = "config/smallperson.data"

    # Checkpoint whose state_dict is loaded into the model
    # (alternative: "checkpoints/yolov3_ckpt.pth").
    model = "checkpoints/ds_1w_prune_0.5_yolov3_ckpt_99_05181112.pth"
|
||
|
|
||
|
# Load the model: build the network from its cfg file and restore trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Darknet(opt.model_def).to(device)  # build the network from the cfg definition

# Load the weights. map_location remaps any tensors that were saved on a CUDA
# device onto `device`. Without it, a GPU-saved checkpoint fails to load on a
# CPU-only machine, even though the line above deliberately falls back to CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer torch releases also accept weights_only=True for plain state_dicts.
model.load_state_dict(torch.load(opt.model, map_location=device))

# Print the model: shows the module hierarchy (layer structure) of the network.
print(model)
|