Face/gender_model.py

50 lines
1.5 KiB
Python
Raw Normal View History

2024-07-29 11:24:25 +08:00
import numpy as np
import mxnet as mx
# Load the gender/age model.
def get_model(ctx, image_size, model_str, layer):
    """Load an MXNet checkpoint and return a bound inference module.

    The checkpoint's symbol graph is truncated at the internal output named
    ``layer + '_output'``, bound for a single 3-channel input, and populated
    with the checkpoint's parameters.

    Args:
        ctx: MXNet context passed to ``mx.mod.Module`` as ``context``.
        image_size: Indexable pair; its two entries become the last two
            dimensions of the bound ``'data'`` shape ``(1, 3, s0, s1)``.
        model_str: Checkpoint locator formatted as ``'prefix,epoch'``.
        layer: Name of the internal layer whose output becomes the module
            output (``'_output'`` is appended to it).

    Returns:
        The bound ``mx.mod.Module`` with parameters set.

    Raises:
        ValueError: If ``model_str`` does not consist of exactly two
            comma-separated fields, or the epoch field is not an integer.
    """
    parts = model_str.split(',')
    # Explicit check instead of ``assert`` so the validation survives
    # ``python -O`` and reports what was wrong with the argument.
    if len(parts) != 2:
        raise ValueError(
            "model_str must be formatted as 'prefix,epoch', got %r" % (model_str,)
        )
    prefix, epoch_text = parts
    epoch = int(epoch_text)
    print('loading', prefix, epoch)

    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Cut the graph at the requested internal layer.
    internals = sym.get_internals()
    sym = internals[layer + '_output']

    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
    return model
class GenderModel:
    """Wrapper around an MXNet module that predicts gender and age.

    The underlying module is loaded through ``get_model`` using the ``'fc1'``
    layer as output. Its first output row is decoded by ``get_ga``.
    """

    def __init__(self, args):
        """Parse the configuration and, if a model is given, load it.

        Args:
            args: Configuration object providing:
                ``gpu`` - device index; a value ``>= 0`` selects
                ``mx.gpu(gpu)``, anything else selects ``mx.cpu()``.
                ``image_size`` - string of two comma-separated integers.
                ``model`` - ``'prefix,epoch'`` checkpoint locator; when
                empty no model is loaded and ``self.model`` stays ``None``.

        Raises:
            ValueError: If ``args.image_size`` is not two comma-separated
                integers.
        """
        self.args = args

        # Parse the purely textual option first so that a malformed value
        # fails before any device context or checkpoint is touched. An
        # explicit check is used because ``assert`` is removed by ``-O``.
        dims = args.image_size.split(',')
        if len(dims) != 2:
            raise ValueError(
                "image_size must be two comma-separated integers, got %r"
                % (args.image_size,)
            )
        image_size = (int(dims[0]), int(dims[1]))

        ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()

        self.model = None
        if args.model:
            self.model = get_model(ctx, image_size, args.model, 'fc1')

        # NOTE(review): these look like face-detector settings; nothing in
        # this class reads them - confirm whether external code does.
        self.det_minsize = 50
        self.det_threshold = [0.6, 0.7, 0.8]
        self.image_size = image_size

    def get_ga(self, data):
        """Run the model on ``data`` and decode gender and age.

        Columns 0-1 of the output are treated as two gender scores and the
        index of the larger one is returned. Columns 2-201 are viewed as 100
        two-way score pairs; the age is the number of pairs whose second
        score is strictly larger than the first. The ``(100, 2)`` reshape
        means the output must contain exactly one row.

        Args:
            data: Batch forwarded unchanged to ``self.model.forward``
                (presumably an ``mx.io.DataBatch`` - confirm with callers).

        Returns:
            Tuple ``(gender, age)``: ``gender`` is the numpy integer index
            (0 or 1) returned by ``np.argmax``; ``age`` is a Python ``int``.

        Raises:
            RuntimeError: If no model was loaded at construction time.
        """
        if self.model is None:
            # Without this guard the call below fails with an opaque
            # AttributeError on None.
            raise RuntimeError(
                'GenderModel has no model loaded; args.model was empty'
            )

        self.model.forward(data, is_train=False)
        ret = self.model.get_outputs()[0].asnumpy()

        gender = np.argmax(ret[:, 0:2].flatten())

        age_pairs = ret[:, 2:202].reshape((100, 2))
        age = int(np.argmax(age_pairs, axis=1).sum())
        return gender, age