151 lines
4.7 KiB
Python
151 lines
4.7 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
import argparse
|
||
import warnings
|
||
import time
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
from src.generate_patches import CropImage
|
||
from src.model_lib.MiniFASNet import MiniFASNetV1, MiniFASNetV2,MiniFASNetV1SE,MiniFASNetV2SE
|
||
from src.data_io import transform as trans
|
||
from src.utility import get_kernel, parse_model_name
|
||
# Install a blanket "ignore" filter: no warning of any category is shown after
# this point.
warnings.filterwarnings('ignore')

# Lookup table from model-type name to the network class implementing it.
# _load_model indexes this table with the type returned by parse_model_name
# for a checkpoint file name, then instantiates the class it finds.
MODEL_MAPPING = {
    'MiniFASNetV1': MiniFASNetV1,
    'MiniFASNetV2': MiniFASNetV2,
    'MiniFASNetV1SE':MiniFASNetV1SE,
    'MiniFASNetV2SE':MiniFASNetV2SE
}
|
||
|
||
class AntiSpoofPredict:
    """Run one anti-spoofing network on one image patch.

    The instance only remembers the torch device that inputs are moved to;
    the network itself is passed to :meth:`predict` on every call.
    """

    def __init__(self, cpu_or_cuda):
        """Select the device used for inference.

        Args:
            cpu_or_cuda: ``"cuda"`` selects the CUDA device; any other value
                selects the CPU.
        """
        super().__init__()
        self.device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")

    def predict(self, img, model):
        """Return the softmax probabilities ``model`` produces for ``img``.

        Args:
            img: Image patch accepted by ``trans.ToTensor``
                (presumably the array returned by the cropper; verify
                against the caller).
            model: Callable network. It is not moved here, so it is expected
                to already live on ``self.device``.

        Returns:
            A numpy array holding the softmax of the model output, copied
            back to the CPU. A leading batch axis of size 1 is added to the
            input before the forward pass.
        """
        to_tensor = trans.Compose([
            trans.ToTensor(),
        ])
        batch = to_tensor(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Call the module itself rather than .forward() so registered
            # hooks run as usual.
            logits = model(batch)
            # Explicit class axis: calling softmax without ``dim`` is
            # deprecated. dim=1 is what the implicit rule selects for a 2-D
            # output (assumes the output is (batch, classes) - TODO confirm).
            probabilities = F.softmax(logits, dim=1).cpu().numpy()
        return probabilities
|
||
|
||
def load_anti_model(model_dir, cpu_or_cuda):
    """Load every checkpoint found in ``model_dir``.

    Args:
        model_dir: Directory whose entries are all treated as checkpoints.
        cpu_or_cuda: Device selector forwarded to ``_load_model``.

    Returns:
        A list with one loaded model per directory entry, in the order
        reported by ``os.listdir(model_dir)``.
    """
    return [
        _load_model(os.path.join(model_dir, entry), cpu_or_cuda)
        for entry in os.listdir(model_dir)
    ]
|
||
|
||
def _load_model(model_path, cpu_or_cuda):
    """Build the network named by a checkpoint file and load its weights.

    The input size and model type are parsed from the file name with
    ``parse_model_name``; the matching class is taken from ``MODEL_MAPPING``
    and the weights are read from the file itself.

    Args:
        model_path: Path to the checkpoint file.
        cpu_or_cuda: ``"cuda"`` selects the CUDA device; any other value
            selects the CPU.

    Returns:
        The instantiated model, on the selected device and in eval mode.

    Raises:
        KeyError: If the parsed model type is not a key of ``MODEL_MAPPING``.
    """
    # define model
    device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")
    model_name = os.path.basename(model_path)
    h_input, w_input, model_type, _ = parse_model_name(model_name)
    kernel_size = get_kernel(h_input, w_input)
    model = MODEL_MAPPING[model_type](conv6_kernel=kernel_size).to(device)

    # load model weight
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=device)

    # Weights saved from a wrapped model (e.g. nn.DataParallel) carry a
    # leading "module." on every key. Strip it only where it really is a
    # prefix, so keys that merely contain the text, or that have no prefix
    # at all, are left intact.
    prefix = "module."
    if any(key.startswith(prefix) for key in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    model.load_state_dict(state_dict)
    model.eval()
    return model
|
||
|
||
|
||
# The video stream captured by the Android APK has a 3:4 width-to-height ratio, so images are restricted to 3:4 to stay consistent with it.
|
||
def check_image(image):
    """Tell whether ``image`` has a 3:4 width-to-height ratio.

    Args:
        image: Object with a ``shape`` whose first two entries are
            ``(height, width)``; further entries (e.g. channels) are ignored,
            so 2-D grayscale arrays are accepted as well.

    Returns:
        ``True`` when ``width / height`` is exactly 3/4. Otherwise a hint is
        printed and ``False`` is returned; this includes images with zero
        height or width.
    """
    height, width = image.shape[:2]
    # Cross-multiply instead of comparing float quotients: exact for integer
    # sizes and no ZeroDivisionError for a zero-height image.
    if height > 0 and width * 4 == height * 3:
        return True
    print("Image is not appropriate!!!\nHeight/Width should be 4/3.")
    return False
|
||
|
||
# Face liveness (anti-spoofing) detection
|
||
def anti_spoofing(image_name, model_dir, cpu_or_cuda, bbox, model_list,
                  threshold=0.915):
    """Classify the face in an image file as real or fake.

    Every model contributes one softmax prediction for a patch cropped
    around ``bbox``. The predictions are summed, and the mean score of class
    index 1 is compared with ``threshold``.

    Args:
        image_name: Path of the image file to read.
        model_dir: Directory holding the checkpoint files. Only the file
            names are used here, to derive patch size and scale per model.
        cpu_or_cuda: Device selector passed to ``AntiSpoofPredict``.
        bbox: Face bounding box handed to ``CropImage.crop`` unchanged.
        model_list: Loaded models. ``model_list[i]`` must belong to the
            i-th entry of ``os.listdir(model_dir)``, which is the order
            ``load_anti_model`` produces.
        threshold: Minimum mean class-1 score for a "real face" verdict.

    Returns:
        A tuple ``(label, score)`` where ``label`` is ``"real face"`` or
        ``"fake face"`` and ``score`` is the mean class-1 score formatted
        with ten decimals.

    Raises:
        ValueError: If the file content cannot be decoded as an image.
    """
    model_test = AntiSpoofPredict(cpu_or_cuda)
    image_cropper = CropImage()

    image = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None; fail with a clear
        # message instead of an AttributeError on ``.shape`` below.
        raise ValueError("Could not decode image: {}".format(image_name))

    h, w = image.shape[:2]
    factor = h / w
    if w > 1000:
        # Shrink wide images to a width of 600 px, keeping the aspect ratio.
        # NOTE(review): bbox is not rescaled here - confirm the caller
        # computes it on an image resized the same way.
        image = cv2.resize(image, (600, int(600 * factor)))

    model_names = os.listdir(model_dir)
    prediction = np.zeros((1, 3))
    # sum the prediction from single model's result
    for index, model_name in enumerate(model_names):
        h_input, w_input, _, scale = parse_model_name(model_name)
        patch = image_cropper.crop(
            org_img=image,
            bbox=bbox,
            scale=scale,
            out_w=w_input,
            out_h=h_input,
            # A name without a scale disables cropping.
            crop=scale is not None,
        )
        prediction += model_test.predict(patch, model_list[index])

    # Average over the models actually evaluated instead of a hard-coded 2,
    # so the threshold keeps its meaning for any ensemble size.
    model_count = len(model_names)
    value = prediction[0][1] / model_count if model_count else 0.0

    verdict = "real face" if value > threshold else "fake face"
    return verdict, '{:.10f}'.format(value)
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows.
    cli_options = (
        ("--device_id", int, 0, "which gpu id, [0/1/2/3]"),
        ("--model_dir", str, "./resources/anti_spoof_models", "model_lib used to test"),
        ("--image_name", str, "000_0.bmp", "image used to test"),
    )

    parser = argparse.ArgumentParser(description="test")
    for flag, value_type, fallback, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback, help=help_text)
    args = parser.parse_args()

    # NOTE: the inference call is disabled here. anti_spoofing() also needs a
    # bbox and a preloaded model_list, which this entry point does not build.
|