Face/anti.py

151 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import cv2
import numpy as np
import argparse
import warnings
import time
import torch
import torch.nn.functional as F
from src.generate_patches import CropImage
from src.model_lib.MiniFASNet import MiniFASNetV1, MiniFASNetV2,MiniFASNetV1SE,MiniFASNetV2SE
from src.data_io import transform as trans
from src.utility import get_kernel, parse_model_name
# Silence every Python warning for the rest of the process (the filter list
# is global, so this also affects modules imported after this one).
warnings.filterwarnings('ignore')

# Architecture name (as returned by ``parse_model_name`` for a model file)
# -> class that constructs that network.
MODEL_MAPPING = dict(
    MiniFASNetV1=MiniFASNetV1,
    MiniFASNetV2=MiniFASNetV2,
    MiniFASNetV1SE=MiniFASNetV1SE,
    MiniFASNetV2SE=MiniFASNetV2SE,
)
class AntiSpoofPredict():
    """Run one anti-spoofing model on a single, already-cropped face image.

    The target device is CUDA only when ``cpu_or_cuda == "cuda"``; any other
    value selects the CPU.
    """

    def __init__(self, cpu_or_cuda):
        super(AntiSpoofPredict, self).__init__()
        self.device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")

    def predict(self, img, model):
        """Return the softmax class probabilities for ``img`` as a NumPy array.

        Args:
            img: cropped face image accepted by ``trans.ToTensor`` (presumably
                the HWC array produced by ``CropImage.crop`` -- TODO confirm).
            model: a loaded network already placed on ``self.device``
                (e.g. the output of ``_load_model``).

        Returns:
            Array of shape ``(1, num_classes)``; the leading 1 is the batch
            dimension added by ``unsqueeze(0)``.
        """
        test_transform = trans.Compose([
            trans.ToTensor(),
        ])
        batch = test_transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Call the module (not .forward) so any registered hooks run.
            logits = model(batch)
        # Explicit class axis: the implicit-dim form is deprecated and picks
        # the axis heuristically (it chose dim=1 for this 2-D output anyway).
        return F.softmax(logits, dim=1).cpu().numpy()
def load_anti_model(model_dir, cpu_or_cuda):
    """Load one model per entry found in ``model_dir``.

    The result follows ``os.listdir(model_dir)`` order. ``anti_spoofing``
    pairs ``model_list[i]`` with the i-th name of its own
    ``os.listdir(model_dir)`` call, so the directory listing order must be the
    same for both calls.

    Args:
        model_dir: directory containing only model weight files.
        cpu_or_cuda: ``"cuda"`` to load onto the GPU; any other value uses CPU.

    Returns:
        List of loaded models, one per directory entry.
    """
    entries = os.listdir(model_dir)
    return [
        _load_model(os.path.join(model_dir, entry), cpu_or_cuda)
        for entry in entries
    ]
def _load_model(model_path, cpu_or_cuda):
    """Build the network described by ``model_path``'s file name and load its weights.

    The base file name is parsed with ``parse_model_name`` to get the input
    height/width and the architecture key into ``MODEL_MAPPING``. The
    ``conv6_kernel`` size comes from ``get_kernel``.

    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper prefix parameter
    names with ``"module."``. When the first key carries that prefix, it is
    removed from every key that has it before loading.

    Args:
        model_path: path to a weight file whose base name encodes the model config.
        cpu_or_cuda: ``"cuda"`` to load onto the GPU; any other value uses the CPU.

    Returns:
        The model on the selected device, switched to eval mode.
    """
    from collections import OrderedDict

    prefix = "module."
    device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")
    h_input, w_input, model_type, _ = parse_model_name(os.path.basename(model_path))
    kernel_size = get_kernel(h_input, w_input)
    model = MODEL_MAPPING[model_type](conv6_kernel=kernel_size).to(device)

    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code -- only load weights from a trusted model directory.
    state_dict = torch.load(model_path, map_location=device)

    # Require a real prefix match (the old substring test would also fire on
    # e.g. "conv.module.weight" and then chop 7 chars off every key).
    # An empty checkpoint falls through to load_state_dict, which reports the
    # missing keys instead of raising a bare StopIteration.
    if next(iter(state_dict), "").startswith(prefix):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    model.load_state_dict(state_dict)
    model.eval()
    return model
# The Android APK captures video frames with a 3:4 width:height ratio, so
# images are restricted to the same 3:4 ratio for consistency.
def check_image(image):
    """Return True if ``image`` has a 3:4 width:height aspect ratio.

    Otherwise prints a message and returns False; an image with zero height
    is also rejected.

    Args:
        image: image array whose first two dimensions are (height, width);
            a trailing channel dimension is optional (grayscale is accepted).
    """
    height, width = image.shape[:2]
    # Exact integer cross-multiplication: no float comparison and no
    # ZeroDivisionError (zero height is rejected explicitly).
    if height == 0 or width * 4 != height * 3:
        print("Image is not appropriate!!!\nHeight/Width should be 4/3.")
        return False
    return True
# Face liveness (anti-spoofing) detection.
def anti_spoofing(image_name, model_dir, cpu_or_cuda, bbox, model_list, threshold=0.915):
    """Classify the face at ``bbox`` in ``image_name`` as real or fake.

    The image is read from disk. Images wider than 1000 px are first resized
    to 600 px wide, keeping the aspect ratio. For every file in ``model_dir``, the
    face is cropped according to the scale and input size parsed from the file
    name, the paired model is run, and the softmax outputs are summed. The
    class-1 probability averaged over the models is compared to ``threshold``.

    Args:
        image_name: path of the image file to read.
        model_dir: directory of model files; each name is parsed with
            ``parse_model_name``.
        cpu_or_cuda: ``"cuda"`` to run on the GPU; any other value uses the CPU.
        bbox: face box passed straight to ``CropImage.crop``; presumably
            (x, y, w, h), the order the earlier debug drawing code used.
            NOTE(review): it is applied to the possibly resized image, so it
            must already be in those coordinates -- confirm callers.
        model_list: models from ``load_anti_model(model_dir, ...)``.
            ``model_list[i]`` is paired with the i-th entry of
            ``os.listdir(model_dir)``, so the listing order must not change
            between the two calls.
        threshold: mean class-1 score above which the face is reported as
            real (default 0.915, the previously hard-coded value).

    Returns:
        Tuple ``(label, score)``. ``label`` is ``"real face"`` or
        ``"fake face"``; ``score`` is the mean class-1 probability formatted
        with 10 decimals.

    Raises:
        ValueError: if the image cannot be decoded or ``model_dir`` is empty.
    """
    model_test = AntiSpoofPredict(cpu_or_cuda)
    image_cropper = CropImage()
    # np.fromfile + imdecode (rather than cv2.imread) also copes with
    # non-ASCII paths on Windows.
    image = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Could not decode image: {}".format(image_name))
    h, w = image.shape[:2]
    factor = h / w
    if w > 1000:
        image = cv2.resize(image, (600, int(600 * factor)))

    model_names = os.listdir(model_dir)
    if not model_names:
        raise ValueError("No model files found in {}".format(model_dir))

    # Sum the class probabilities from every model.
    prediction = np.zeros((1, 3))
    for index, model_name in enumerate(model_names):
        h_input, w_input, _, scale = parse_model_name(model_name)
        img = image_cropper.crop(
            org_img=image,
            bbox=bbox,
            scale=scale,
            out_w=w_input,
            out_h=h_input,
            # File names without a scale are run with crop disabled.
            crop=scale is not None,
        )
        prediction += model_test.predict(img, model_list[index])

    # Average over the models actually run (was a hard-coded /2, which is
    # only correct for exactly two models).
    value = prediction[0][1] / len(model_names)
    label = "real face" if value > threshold else "fake face"
    return label, '{:.10f}'.format(value)
if __name__ == "__main__":
    # Command-line entry point: classify the face inside --bbox in --image_name.
    parser = argparse.ArgumentParser(description="test")
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="which gpu id, [0/1/2/3]")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="./resources/anti_spoof_models",
        help="model_lib used to test")
    parser.add_argument(
        "--image_name",
        type=str,
        default="000_0.bmp",
        help="image used to test")
    parser.add_argument(
        "--bbox",
        type=int,
        nargs=4,
        required=True,
        metavar=("X", "Y", "W", "H"),
        help="face box (x, y, width, height) in pixels of the image as the "
             "models see it (images wider than 1000 px are first resized "
             "to 600 px wide)")
    args = parser.parse_args()

    # anti_spoofing expects the string "cuda" or "cpu", not a GPU index, so
    # the index is used only to select the current CUDA device.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
        device = "cuda"
    else:
        device = "cpu"

    models = load_anti_model(args.model_dir, device)
    label, score = anti_spoofing(
        args.image_name, args.model_dir, device, args.bbox, models)
    print(label, score)