Face/gender_age.py

99 lines
2.9 KiB
Python

import datetime
import mxnet as mx
import numpy as np
from retinaface_detect import detect_one, load_retinaface_model, set_retinaface_conf
# Age/gender model configuration
class ConfGenderModel:
    """Plain container for the settings of the gender/age pipeline.

    The constructor stores each argument unchanged on the instance under
    the attribute of the same name; no validation or conversion is done.
    """

    def __init__(self, image_size, image, model, gpu, det):
        # Attributes mirror the constructor parameters one-to-one.
        self.image_size = image_size
        self.image = image
        self.model = model
        self.gpu = gpu
        self.det = det
# Instantiate a configuration
def set_gender_conf(image_size='112,112',
                    image=r'C:\Users\ASUS\Desktop\man.png',
                    gpu=-1,
                    model='model/model,0',
                    det=0):
    """Build the configuration object for the gender/age pipeline.

    Called without arguments this returns the same configuration as
    before. Every field can now be overridden by keyword, so callers on
    another machine do not have to edit the built-in sample image path.

    Args:
        image_size: two comma-separated integers, e.g. ``'112,112'``.
        image: path of the image to analyse.
        gpu: device id; ``load_gender_model`` uses the CPU when negative.
        model: ``'prefix,epoch'`` string naming the checkpoint to load.
        det: stored on the configuration as ``det``; not read in this file.

    Returns:
        ConfGenderModel: configuration populated with the given values.
    """
    return ConfGenderModel(image_size=image_size,
                           image=image,
                           gpu=gpu,
                           model=model,
                           det=det)
# Load the gender/age model
def load_gender_model(args, layer):
    """Load the gender/age checkpoint and bind it for single-image inference.

    Args:
        args: configuration object exposing ``image_size`` (two
            comma-separated integers), ``model`` (``'prefix,epoch'``) and
            ``gpu`` (device id; a negative value selects the CPU).
        layer: name of the internal layer to use as the network output;
            ``'_output'`` is appended when looking it up in the symbol.

    Returns:
        An ``mx.mod.Module`` bound to a ``(1, 3, size[0], size[1])`` input
        named ``'data'`` with the checkpoint parameters loaded.

    Raises:
        ValueError: if ``args.image_size`` or ``args.model`` does not have
            exactly two comma-separated fields, or a numeric field is not
            an integer.
    """
    # Validate the configuration first, with real exceptions: ``assert`` is
    # stripped under ``python -O`` and would let malformed input through.
    size_fields = args.image_size.split(',')
    if len(size_fields) != 2:
        raise ValueError(
            "image_size must be two comma-separated integers, got %r"
            % (args.image_size,))
    image_size = (int(size_fields[0]), int(size_fields[1]))

    model_fields = args.model.split(',')
    if len(model_fields) != 2:
        raise ValueError(
            "model must have the form 'prefix,epoch', got %r" % (args.model,))
    prefix = model_fields[0]
    epoch = int(model_fields[1])

    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()

    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Truncate the network at the requested internal layer.
    all_layers = sym.get_internals()
    sym = all_layers[layer + '_output']
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
    return model
# Forward inference
def get_ga(model, img):
    """Run one forward pass and decode gender and age from the output.

    Args:
        model: object providing ``forward(batch, is_train=...)`` and
            ``get_outputs()``; the first output must offer ``asnumpy()``.
        img: the batch handed to ``model.forward``.

    Returns:
        tuple: ``(gender, age)``. ``gender`` is the argmax over the first
        two output columns. ``age`` is an ``int``: columns 2..201 are read
        as 100 score pairs, and ``age`` counts the pairs whose second
        score is the larger one.
    """
    model.forward(img, is_train=False)
    scores = model.get_outputs()[0].asnumpy()

    gender = np.argmax(scores[:, :2].flatten())

    # One 0/1 decision per score pair; the age is the number of ones.
    pair_votes = np.argmax(scores[:, 2:202].reshape((100, 2)), axis=1)
    return gender, int(pair_votes.sum())
# Predict the gender and age of every person in the face list
def gender_age(img_list, gender_model):
    """Predict gender and age for every face in ``img_list``.

    Args:
        img_list: array-like of face images, one per entry along the first
            axis. Values are rescaled as ``x * 127.5 + 127.5`` before
            inference (presumably the crops arrive normalised to [-1, 1];
            confirm against the detector). The input itself is left
            unmodified. ``None`` is treated like an empty list.
        gender_model: model object passed through to ``get_ga``.

    Returns:
        tuple: ``(gender_list, age_list)``, index-aligned with the input.
        ``gender_list`` holds ``"man"`` / ``'woman'`` and ``age_list``
        holds the ages returned by ``get_ga``. Both are empty when there
        are no faces.
    """
    gender_list = []
    age_list = []
    if img_list is None or len(img_list) == 0:
        print("find no face")
        return gender_list, age_list

    time_now = datetime.datetime.now()
    # Rescale into a new array. The previous in-place ``*=`` / ``+=``
    # silently altered the caller's data (a second call on the same array
    # would double-scale it) and failed for lists and integer arrays.
    faces = np.asarray(img_list) * 127.5 + 127.5
    for face in faces:
        # The model takes one image at a time, so add a leading batch axis.
        batch = mx.io.DataBatch(
            data=(mx.nd.array(np.expand_dims(face, axis=0)),))
        gender, age = get_ga(gender_model, batch)
        gender_list.append("man" if gender == 1 else 'woman')
        age_list.append(age)
    elapsed = datetime.datetime.now() - time_now
    print('time cost', elapsed.total_seconds())
    return gender_list, age_list
if __name__ == "__main__":
    # Demo run: detect the faces in the configured image, then print the
    # predicted gender of each one.
    gender_conf = set_gender_conf()
    detector_conf = set_retinaface_conf()

    # Load the gender/age network first, then the face detector.
    age_gender_net = load_gender_model(gender_conf, 'fc1')
    detector = load_retinaface_model(detector_conf)

    # The second value returned by detect_one is not needed here.
    faces, _box_and_point = detect_one(gender_conf.image, detector,
                                       detector_conf)
    genders, _ages = gender_age(faces, age_gender_net)
    print(genders)