import datetime

import mxnet as mx
import numpy as np

from retinaface_detect import detect_one, load_retinaface_model, set_retinaface_conf


# Age/gender model configuration.
class ConfGenderModel(object):
    """Plain container for the gender/age model settings.

    Attributes:
        image_size: "A,B" string; the model is bound with input shape (1, 3, A, B).
        image: path of the image the ``__main__`` script runs detection on.
        model: "prefix,epoch" string naming the MXNet checkpoint to load.
        gpu: GPU index to run on; a negative value selects the CPU.
        det: stored on the config but not read by the code in this module.
    """

    def __init__(self, image_size, image, model, gpu, det):
        self.image_size = image_size
        self.image = image
        self.gpu = gpu
        self.model = model
        self.det = det


# Build a configuration instance.
def set_gender_conf(image_size='112,112', image=r'C:\Users\ASUS\Desktop\man.png',
                    gpu=-1, model='model/model,0', det=0):
    """Return a ConfGenderModel; each field can be overridden by keyword.

    Calling with no arguments reproduces the original hard-coded configuration.
    """
    return ConfGenderModel(image_size=image_size, image=image, model=model,
                           gpu=gpu, det=det)


def _parse_pair(text, what):
    """Split a comma-separated "A,B" string into its two parts or raise ValueError."""
    parts = text.split(',')
    if len(parts) != 2:
        raise ValueError("%s must look like 'A,B', got %r" % (what, text))
    return parts[0], parts[1]


# Load the gender/age model.
def load_gender_model(args, layer):
    """Load an MXNet checkpoint truncated at ``layer`` and bind it for one image.

    Args:
        args: ConfGenderModel-like object providing ``gpu``, ``image_size`` and ``model``.
        layer: name of the internal layer whose ``<layer>_output`` becomes the output.

    Returns:
        A bound ``mx.mod.Module`` with the checkpoint parameters set.

    Raises:
        ValueError: if ``args.image_size`` or ``args.model`` is not an "A,B" pair,
            or one of its numeric parts is not an integer.
    """
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()

    # Validate image_size before model, matching the original check order.
    image_size = tuple(int(v) for v in _parse_pair(args.image_size, 'image_size'))
    prefix, epoch_text = _parse_pair(args.model, 'model')
    epoch = int(epoch_text)

    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Cut the graph at the requested internal layer and use its output as the head.
    sym = sym.get_internals()[layer + '_output']
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Bound for a batch of exactly one image; get_ga relies on that batch size.
    model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
    return model


# Forward inference.
def get_ga(model, img):
    """Run one forward pass and decode (gender, age) from the single output row.

    Args:
        model: Module returned by load_gender_model (bound for batch size 1).
        img: ``mx.io.DataBatch`` holding one image.

    Returns:
        (gender, age): ``gender`` is the argmax over output columns 0-1 (0 or 1).
        ``age`` is the number of the 100 two-way scores in columns 2-201 whose
        argmax is 1. This presumably follows an "age > k" per-bin encoding;
        confirm against the model's training setup.
    """
    model.forward(img, is_train=False)
    ret = model.get_outputs()[0].asnumpy()
    gender = np.argmax(ret[:, 0:2].flatten())
    # reshape((100, 2)) only succeeds when the batch holds a single row.
    age_votes = np.argmax(ret[:, 2:202].reshape((100, 2)), axis=1)
    age = int(sum(age_votes))
    return gender, age


# Predict gender and age for every face in the face list.
def gender_age(img_list, gender_model):
    """Predict gender and age for each face crop in ``img_list``.

    Args:
        img_list: sequence/array of face crops as produced by ``detect_one``.
            Presumably each crop is (3, H, W) so that adding a batch axis matches
            the (1, 3, H, W) shape bound in load_gender_model; TODO confirm.
            The input is no longer modified in place.
        gender_model: Module returned by load_gender_model.

    Returns:
        (gender_list, age_list): ``gender_list`` holds "man" / 'woman' strings and
        ``age_list`` holds ints, one entry per face. Both lists are empty when no
        faces are given.
    """
    gender_list = []
    age_list = []
    if img_list is None or len(img_list) == 0:
        print("find no face")
        return gender_list, age_list

    time_now = datetime.datetime.now()
    # Map the crops back to pixel scale. Presumably detect_one normalizes pixels
    # as (x - 127.5) / 127.5; verify against retinaface_detect. A new array is
    # built so the caller's data is left untouched; the old in-place `*=`/`+=`
    # mutated it and failed on lists and integer arrays.
    faces = np.asarray(img_list) * 127.5 + 127.5

    for face in faces:
        batch = mx.io.DataBatch(data=(mx.nd.array(np.expand_dims(face, axis=0)),))
        gender, age = get_ga(gender_model, batch)
        # Output class 1 is reported as "man"; anything else as 'woman'.
        gender_list.append("man" if gender == 1 else 'woman')
        age_list.append(age)

    time_now2 = datetime.datetime.now()
    diff = time_now2 - time_now
    print('time cost', diff.total_seconds())
    return gender_list, age_list


if __name__ == "__main__":
    args = set_gender_conf()
    retinaface_args = set_retinaface_conf()
    gender_model = load_gender_model(args, 'fc1')
    retinaface_model = load_retinaface_model(retinaface_args)
    # Only the face crops are used here; the second return value is unused.
    img_list, box_and_point = detect_one(args.image, retinaface_model, retinaface_args)
    gender_list, age_list = gender_age(img_list, gender_model)
    print(gender_list)