50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import numpy as np
|
|
import mxnet as mx
|
|
|
|
|
|
# Load the gender/age model.
def get_model(ctx, image_size, model_str, layer):
    """Load an MXNet checkpoint and bind it for single-image inference.

    Args:
        ctx: MXNet context (device) the module is bound on.
        image_size: Pair used as the last two dims of the ``(1, 3, A, B)``
            shape bound for the ``data`` input.
        model_str: Checkpoint spec ``"prefix,epoch"``. Only the last comma
            separates the epoch, so the prefix itself may contain commas.
        layer: Name of the internal layer whose output becomes the module's
            output (looked up as ``layer + '_output'`` in the symbol graph).

    Returns:
        A bound ``mx.mod.Module`` with the checkpoint parameters set.

    Raises:
        ValueError: If ``model_str`` has no comma or its epoch is not an
            integer.
    """
    prefix, sep, epoch_str = model_str.rpartition(',')
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not sep:
        raise ValueError(f"model_str must be 'prefix,epoch', got {model_str!r}")
    try:
        epoch = int(epoch_str)
    except ValueError:
        raise ValueError(
            f"epoch in model_str must be an integer, got {epoch_str!r}"
        ) from None

    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Truncate the graph so the requested internal layer is the output.
    sym = sym.get_internals()[layer + '_output']
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Inference only: for_training=False avoids allocating gradient buffers.
    model.bind(
        data_shapes=[('data', (1, 3, image_size[0], image_size[1]))],
        for_training=False,
    )
    model.set_params(arg_params, aux_params)
    return model
|
|
|
|
|
|
class GenderModel:
    """Gender/age predictor wrapping an MXNet module loaded by ``get_model``.

    ``args`` is read for three attributes:

    * ``gpu`` -- device index; a negative value selects the CPU.
    * ``image_size`` -- ``"A,B"`` string of two integers, forwarded to
      ``get_model`` as the last two dims of the ``(1, 3, A, B)`` input.
    * ``model`` -- ``"prefix,epoch"`` checkpoint spec; when empty no network
      is loaded and ``get_ga`` cannot be used.
    """

    def __init__(self, args):
        self.args = args
        # Validate the size string first so a bad config fails before any
        # device context or checkpoint is touched.
        image_size = self._parse_image_size(args.image_size)
        ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()

        self.model = None
        if args.model:
            self.model = get_model(ctx, image_size, args.model, 'fc1')

        # Detection settings; not read anywhere inside this class.
        self.det_minsize = 50
        self.det_threshold = [0.6, 0.7, 0.8]
        self.image_size = image_size

    @staticmethod
    def _parse_image_size(spec):
        """Parse ``"A,B"`` into an ``(int, int)`` tuple.

        Raises:
            ValueError: If ``spec`` does not hold exactly two integers.
        """
        parts = spec.split(',')
        if len(parts) != 2:
            raise ValueError(f"image_size must be 'A,B', got {spec!r}")
        return int(parts[0]), int(parts[1])

    def get_ga(self, data):
        """Run one forward pass and decode ``(gender, age)``.

        Args:
            data: Batch passed unchanged to ``Module.forward``.

        Returns:
            ``(gender, age)`` as ints. ``gender`` is the argmax over output
            columns 0-1; ``age`` is the number of the 100 column pairs in
            columns 2-201 whose second entry is strictly larger.

        Raises:
            RuntimeError: If no network was loaded (empty ``args.model``).
        """
        if self.model is None:
            raise RuntimeError(
                "GenderModel has no network loaded; construct it with a "
                "non-empty args.model"
            )
        self.model.forward(data, is_train=False)
        output = self.model.get_outputs()[0].asnumpy()

        gender = int(np.argmax(output[:, 0:2].flatten()))
        # The reshape to (100, 2) only succeeds for a single-row output,
        # matching the batch size of 1 used in get_model's bind.
        age_bins = output[:, 2:202].reshape((100, 2))
        age = int(np.argmax(age_bins, axis=1).sum())
        return gender, age