99 lines
2.9 KiB
Python
99 lines
2.9 KiB
Python
|
import datetime
|
||
|
import mxnet as mx
|
||
|
import numpy as np
|
||
|
from retinaface_detect import detect_one, load_retinaface_model, set_retinaface_conf
|
||
|
|
||
|
|
||
|
# Gender/age configuration
|
||
|
class ConfGenderModel(object):
    """Plain settings holder for the gender/age pipeline.

    Every constructor argument is stored unchanged under an attribute of the
    same name; nothing is validated or converted here.

    Attributes:
        image_size: stored as given; ``load_gender_model`` splits it on ','
            and converts both fields with ``int``.
        image: stored as given; the script entry point passes it to
            ``detect_one``.
        gpu: stored as given; ``load_gender_model`` uses ``mx.gpu(gpu)`` when
            it is >= 0 and ``mx.cpu()`` otherwise.
        model: stored as given; ``load_gender_model`` splits it on ',' into a
            checkpoint prefix and an integer epoch.
        det: stored as given; not read anywhere else in this file.
    """

    def __init__(self, image_size, image, model, gpu, det):
        # Assigned left to right, so attribute creation order stays
        # image_size, image, gpu, model, det.
        self.image_size, self.image = image_size, image
        self.gpu, self.model, self.det = gpu, model, det
|
||
|
|
||
|
|
||
|
# Instantiate a default configuration
|
||
|
def set_gender_conf():
    """Return a ``ConfGenderModel`` populated with this script's defaults.

    Returns:
        ConfGenderModel with image_size '112,112', a hard-coded image path,
        gpu -1, model 'model/model,0' and det 0.
    """
    defaults = {
        'image_size': '112,112',
        # NOTE: machine-specific absolute path; callers on other machines
        # need to overwrite ``image`` on the returned object.
        'image': r'C:\Users\ASUS\Desktop\man.png',
        'gpu': -1,
        'model': 'model/model,0',
        'det': 0,
    }
    return ConfGenderModel(**defaults)
|
||
|
|
||
|
|
||
|
# Load the gender/age model
|
||
|
def load_gender_model(args, layer):
    """Load an MXNet checkpoint and bind it with one internal layer as output.

    Args:
        args: object exposing ``gpu`` (an int; >= 0 selects ``mx.gpu(gpu)``,
            anything else selects ``mx.cpu()``), ``image_size`` (a string of
            two comma-separated ints) and ``model`` (a ``'prefix,epoch'``
            string).
        layer: internal layer name; ``'_output'`` is appended to look it up
            in the symbol's internals.

    Returns:
        An ``mx.mod.Module`` bound to a single ``'data'`` input of shape
        ``(1, 3, size0, size1)`` with the checkpoint parameters set.

    Raises:
        AssertionError: if ``args.image_size`` or ``args.model`` does not
            split into exactly two comma-separated fields.
    """
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()

    size_fields = args.image_size.split(',')
    assert len(size_fields) == 2
    size0, size1 = int(size_fields[0]), int(size_fields[1])

    model_fields = args.model.split(',')
    assert len(model_fields) == 2
    prefix = model_fields[0]
    epoch = int(model_fields[1])
    print('loading', prefix, epoch)

    symbol, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Cut the graph at the requested layer so its activations are the output.
    layer_symbol = symbol.get_internals()[layer + '_output']

    module = mx.mod.Module(symbol=layer_symbol, context=ctx, label_names=None)
    input_shape = (1, 3, size0, size1)
    module.bind(data_shapes=[('data', input_shape)])
    module.set_params(arg_params, aux_params)
    return module
|
||
|
|
||
|
|
||
|
# Forward inference
|
||
|
def get_ga(model, img):
    """Run one forward pass and decode gender and age from the first output.

    Args:
        model: object providing ``forward(batch, is_train=...)`` and
            ``get_outputs()``; the first output must offer ``asnumpy()``.
        img: batch handed to ``model.forward`` unchanged.

    Returns:
        Tuple ``(gender, age)``. ``gender`` is the argmax index over output
        columns 0-1. ``age`` is an ``int``: columns 2-201 are viewed as 100
        pairs and ``age`` counts the pairs whose second value is strictly
        larger than the first (ties resolve to index 0).

    Note:
        The ``(100, 2)`` reshape only succeeds when columns 2-201 hold
        exactly 200 values, i.e. for a single-row output.
    """
    model.forward(img, is_train=False)
    scores = model.get_outputs()[0].asnumpy()

    gender = scores[:, 0:2].flatten().argmax()

    age_votes = scores[:, 2:202].reshape((100, 2)).argmax(axis=1)
    age = int(age_votes.sum())
    return gender, age
|
||
|
|
||
|
|
||
|
# Predict the gender and age of every person in the face list
|
||
|
def gender_age(img_list, gender_model):
    """Predict gender and age for every face crop in ``img_list``.

    Each crop is rescaled with ``x * 127.5 + 127.5``, wrapped into a
    one-image ``mx.io.DataBatch`` and passed to ``get_ga``.

    Args:
        img_list: sequence or array of face crops; may be empty.
            NOTE(review): the rescale suggests the crops arrive normalised
            to [-1, 1] -- confirm against ``detect_one``.
        gender_model: model object forwarded to ``get_ga``.

    Returns:
        Tuple ``(gender_list, age_list)`` in input order. ``gender_list``
        holds ``"man"`` where ``get_ga`` reports gender 1 and ``'woman'``
        otherwise; ``age_list`` holds the ages ``get_ga`` returns. Both are
        empty when ``img_list`` is empty.
    """
    gender_list = []
    age_list = []
    # len() rather than truthiness: a multi-element numpy array has no
    # unambiguous truth value.
    if len(img_list) == 0:
        print("find no face")
        return gender_list, age_list

    time_now = datetime.datetime.now()

    # Rescale into a fresh array instead of `*=` / `+=` on the argument:
    # the in-place form overwrote the caller's data and failed outright for
    # plain lists or integer arrays.
    faces = np.asarray(img_list) * 127.5 + 127.5

    for face in faces:
        batch = mx.io.DataBatch(
            data=(mx.nd.array(np.expand_dims(face, axis=0)),))
        gender, age = get_ga(gender_model, batch)
        gender_list.append("man" if gender == 1 else 'woman')
        age_list.append(age)

    diff = datetime.datetime.now() - time_now
    print('time cost', diff.total_seconds())
    return gender_list, age_list
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Build both configurations, then load the two models in the same order.
    gender_conf = set_gender_conf()
    detector_conf = set_retinaface_conf()
    gender_net = load_gender_model(gender_conf, 'fc1')
    detector = load_retinaface_model(detector_conf)

    # Detect on the configured image; the second return value is not used.
    faces, _boxes_and_points = detect_one(
        gender_conf.image, detector, detector_conf)

    # Only the gender labels are printed; the ages are computed but unused.
    genders, _ages = gender_age(faces, gender_net)
    print(genders)
|