151 lines
4.7 KiB
Python
151 lines
4.7 KiB
Python
|
import os
|
|||
|
import cv2
|
|||
|
import numpy as np
|
|||
|
import argparse
|
|||
|
import warnings
|
|||
|
import time
|
|||
|
import torch
|
|||
|
import torch.nn.functional as F
|
|||
|
|
|||
|
from src.generate_patches import CropImage
|
|||
|
from src.model_lib.MiniFASNet import MiniFASNetV1, MiniFASNetV2,MiniFASNetV1SE,MiniFASNetV2SE
|
|||
|
from src.data_io import transform as trans
|
|||
|
from src.utility import get_kernel, parse_model_name
|
|||
|
# Silence every warning for the whole process: this filter is installed when
# the module is imported and is not restricted to warnings raised here.
warnings.simplefilter("ignore")
|
|||
|
|
|||
|
# Architecture name (as returned by ``parse_model_name`` for a checkpoint file
# name) -> model class used to instantiate that checkpoint in ``_load_model``.
MODEL_MAPPING = dict(
    MiniFASNetV1=MiniFASNetV1,
    MiniFASNetV2=MiniFASNetV2,
    MiniFASNetV1SE=MiniFASNetV1SE,
    MiniFASNetV2SE=MiniFASNetV2SE,
)
|
|||
|
|
|||
|
class AntiSpoofPredict:
    """Run single-image inference with an already-loaded anti-spoofing model.

    Args:
        cpu_or_cuda: ``"cuda"`` selects the CUDA device; any other value
            selects the CPU.
    """

    def __init__(self, cpu_or_cuda):
        super().__init__()
        self.device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")

    def predict(self, img, model):
        """Return softmax class probabilities for ``img`` as a NumPy array.

        Args:
            img: Image accepted by the project's ``trans.ToTensor`` transform
                (presumably the cropped H x W x C array produced by
                ``CropImage.crop`` -- TODO confirm).
            model: ``torch.nn.Module`` expected to live on ``self.device``
                and to already be in eval mode.

        Returns:
            NumPy array of probabilities for a batch of one image. The class
            axis is dim 1; callers in this file add it to a ``(1, 3)`` array.
        """
        test_transform = trans.Compose([
            trans.ToTensor(),
        ])
        # Add the batch dimension the model expects and move to the device.
        batch = test_transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Calling the module (not .forward) keeps any registered hooks.
            logits = model(batch)
            # Explicit dim: implicit-dim softmax is deprecated; for 2-D
            # (batch, classes) logits the implicit choice was dim=1 anyway.
            probs = F.softmax(logits, dim=1)
        return probs.cpu().numpy()
|
|||
|
|
|||
|
def load_anti_model(model_dir, cpu_or_cuda):
    """Load one model per entry of ``model_dir``.

    Models are returned in ``os.listdir(model_dir)`` order. ``anti_spoofing``
    pairs this list by index with its own ``os.listdir(model_dir)`` call, so
    the directory must not change between the two calls.

    Args:
        model_dir: Directory containing only checkpoint files.
        cpu_or_cuda: ``"cuda"`` to load onto the GPU, anything else for CPU.

    Returns:
        List of loaded models (empty if the directory is empty).
    """
    return [
        _load_model(os.path.join(model_dir, entry), cpu_or_cuda)
        for entry in os.listdir(model_dir)
    ]
|
|||
|
|
|||
|
def _strip_module_prefix(state_dict, prefix="module."):
    """Return ``state_dict`` with a leading ``prefix`` removed from its keys.

    Such prefixes are typically left by saving a ``torch.nn.DataParallel``
    wrapper. If no key starts with ``prefix`` the input mapping is returned
    unchanged (same object). Otherwise a new ``OrderedDict`` is built, in
    which only the keys that actually start with ``prefix`` are shortened.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    from collections import OrderedDict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_model(model_path, cpu_or_cuda):
    """Build the model described by ``model_path``'s file name and load its weights.

    The input size and architecture are parsed from the file name with
    ``parse_model_name``. The architecture name selects the class in
    ``MODEL_MAPPING``.

    Args:
        model_path: Path to a checkpoint file holding a state dict.
        cpu_or_cuda: ``"cuda"`` to load onto the GPU, anything else for CPU.

    Returns:
        The model, on the selected device and switched to eval mode.

    Raises:
        KeyError: If the parsed architecture name is not in ``MODEL_MAPPING``.
        RuntimeError: If the checkpoint's keys do not match the model
            (raised by ``load_state_dict``).
    """
    device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")

    # Define the model from the metadata encoded in the file name.
    model_name = os.path.basename(model_path)
    h_input, w_input, model_type, _ = parse_model_name(model_name)
    kernel_size = get_kernel(h_input, w_input)
    model = MODEL_MAPPING[model_type](conv6_kernel=kernel_size).to(device)

    # Load the weights. NOTE(review): torch.load unpickles the file, so only
    # load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(_strip_module_prefix(state_dict))
    model.eval()
    return model
|
|||
|
|
|||
|
|
|||
|
# The video stream captured by the Android APK has a 3:4 width:height ratio;
# to stay consistent with it, images are restricted to a 3:4 aspect ratio.
def check_image(image):
    """Return True if ``image`` has a 3:4 width:height aspect ratio.

    Otherwise a message is printed and False is returned.

    Args:
        image: Array whose first two dimensions are (height, width), e.g. a
            decoded OpenCV image. 2-D (grayscale) arrays are accepted too.

    Returns:
        True for a 3:4 image, False otherwise (including zero-height images).
    """
    height, width = image.shape[:2]
    # Integer cross-multiplication is exact (no float ratio comparison) and
    # cannot divide by zero.
    if height == 0 or width * 4 != height * 3:
        print("Image is not appropriate!!!\nHeight/Width should be 4/3.")
        return False
    return True
|
|||
|
|
|||
|
def _read_and_downscale_image(image_name):
    """Decode ``image_name`` with OpenCV; shrink images wider than 1000 px.

    The file is read with ``np.fromfile`` + ``cv2.imdecode`` rather than
    ``cv2.imread`` (presumably so non-ASCII paths work -- TODO confirm).
    Wide images are resized to 600 px wide, preserving the aspect ratio.

    Raises:
        ValueError: If OpenCV cannot decode the file contents.
    """
    image = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("could not decode image {!r}".format(image_name))
    h, w = image.shape[:2]
    factor = h / w
    if w > 1000:
        # NOTE(review): the caller's bbox is not rescaled along with the
        # image; confirm it is expressed in the resized image's coordinates.
        image = cv2.resize(image, (600, int(600 * factor)))
    return image


# Face liveness (anti-spoofing) detection.
def anti_spoofing(image_name, model_dir, cpu_or_cuda, bbox, model_list,
                  threshold=0.915):
    """Classify the face inside ``bbox`` as a real or a fake face.

    For every entry of ``model_dir`` the face is cropped at the size/scale
    encoded in that entry's name. It is run through the corresponding model
    in ``model_list``, and the softmax outputs are summed. The class-1 score,
    averaged over the models, is compared with ``threshold``.

    Args:
        image_name: Path of the image file to read.
        model_dir: Directory of checkpoint files; each name is parsed with
            ``parse_model_name``.
        cpu_or_cuda: ``"cuda"`` to run on the GPU, anything else for CPU.
        bbox: Face box forwarded unchanged to ``CropImage.crop``
            (presumably ``[x, y, w, h]`` -- TODO confirm).
        model_list: Loaded models, index-aligned with
            ``os.listdir(model_dir)`` (as returned by ``load_anti_model``).
        threshold: Averaged class-1 score above which the face is "real".

    Returns:
        Tuple ``(label, score)``. ``label`` is ``"real face"`` or
        ``"fake face"``; ``score`` is the averaged class-1 score formatted
        with 10 decimals.

    Raises:
        ValueError: If the image cannot be decoded or ``model_dir`` is empty.
    """
    model_names = os.listdir(model_dir)
    if not model_names:
        raise ValueError("no anti-spoofing models found in {!r}".format(model_dir))

    predictor = AntiSpoofPredict(cpu_or_cuda)
    cropper = CropImage()
    image = _read_and_downscale_image(image_name)

    # Sum the softmax outputs of the individual models.
    prediction = np.zeros((1, 3))
    for index, model_name in enumerate(model_names):
        h_input, w_input, _, scale = parse_model_name(model_name)
        img = cropper.crop(
            org_img=image,
            bbox=bbox,
            scale=scale,
            out_w=w_input,
            out_h=h_input,
            crop=scale is not None,  # no scale in the name -> no crop
        )
        prediction += predictor.predict(img, model_list[index])

    # Average over the actual number of models instead of assuming two.
    value = prediction[0][1] / len(model_names)
    if value > threshold:
        return "real face", '{:.10f}'.format(value)
    return "fake face", '{:.10f}'.format(value)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    desc = "test"
    parser = argparse.ArgumentParser(description=desc)
    # NOTE(review): --device_id is parsed but not used anywhere; device
    # selection goes through --device, which matches anti_spoofing's API.
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="which gpu id, [0/1/2/3]")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="./resources/anti_spoof_models",
        help="model_lib used to test")
    parser.add_argument(
        "--image_name",
        type=str,
        default="000_0.bmp",
        help="image used to test")
    parser.add_argument(
        "--device",
        type=str,
        choices=("cpu", "cuda"),
        default="cpu",
        help="device to run inference on")
    parser.add_argument(
        "--bbox",
        type=int,
        nargs=4,
        metavar=("X", "Y", "W", "H"),
        default=None,
        help="face box; inference runs only when this is given")
    args = parser.parse_args()

    # anti_spoofing needs a face box and preloaded models. Without --bbox the
    # script only parses its arguments, as it did before.
    if args.bbox is not None:
        models = load_anti_model(args.model_dir, args.device)
        label, score = anti_spoofing(
            args.image_name, args.model_dir, args.device, args.bbox, models)
        print(label, score)
|