Face/anti.py

151 lines
4.7 KiB
Python
Raw Normal View History

2024-07-29 11:24:25 +08:00
import os
import cv2
import numpy as np
import argparse
import warnings
import time
import torch
import torch.nn.functional as F
from src.generate_patches import CropImage
from src.model_lib.MiniFASNet import MiniFASNetV1, MiniFASNetV2,MiniFASNetV1SE,MiniFASNetV2SE
from src.data_io import transform as trans
from src.utility import get_kernel, parse_model_name
# Suppress every warning category for the whole process once this module is imported.
warnings.filterwarnings('ignore')

# Maps the architecture name taken from a model file name to the network
# class that is instantiated for it.
MODEL_MAPPING = dict(
    MiniFASNetV1=MiniFASNetV1,
    MiniFASNetV2=MiniFASNetV2,
    MiniFASNetV1SE=MiniFASNetV1SE,
    MiniFASNetV2SE=MiniFASNetV2SE,
)
class AntiSpoofPredict:
    """Evaluate an anti-spoofing network on a single image.

    Args:
        cpu_or_cuda: ``"cuda"`` selects the CUDA device; any other value
            selects the CPU.
    """

    def __init__(self, cpu_or_cuda):
        super().__init__()
        self.device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")

    def predict(self, img, model):
        """Return the softmax scores ``model`` produces for ``img``.

        Args:
            img: Image accepted by ``trans.ToTensor``.
            model: Network to evaluate on a batch holding just this image.
                The model is not moved to ``self.device`` here.

        Returns:
            The softmax of the model output as a NumPy array.
        """
        to_tensor = trans.Compose([
            trans.ToTensor(),
        ])
        # Add the batch dimension and move the input to the selected device.
        batch = to_tensor(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Call the module itself instead of ``.forward`` so that any
            # registered hooks are honoured.
            logits = model(batch)
            # Softmax without ``dim`` is deprecated; dim=1 is the dimension
            # the implicit rule selects for a 2-D (batch, classes) output.
            # NOTE(review): assumes the model output is 2-D -- confirm.
            scores = F.softmax(logits, dim=1).cpu().numpy()
        return scores
def load_anti_model(model_dir, cpu_or_cuda):
    """Load one model per entry found in ``model_dir``.

    Args:
        model_dir: Directory whose entries are all treated as model files.
        cpu_or_cuda: Device selector forwarded unchanged to ``_load_model``.

    Returns:
        List of loaded models, in the order ``os.listdir(model_dir)``
        reported the entries.
    """
    return [
        _load_model(os.path.join(model_dir, entry), cpu_or_cuda)
        for entry in os.listdir(model_dir)
    ]
def _load_model(model_path, cpu_or_cuda):
    """Build the network named by a weight file and load its weights.

    The file name is parsed with ``parse_model_name`` to obtain the input
    height/width and the architecture name; the architecture is looked up in
    ``MODEL_MAPPING`` and instantiated with the kernel size returned by
    ``get_kernel``.

    Args:
        model_path: Path of the weight file.
        cpu_or_cuda: ``"cuda"`` selects the CUDA device; any other value
            selects the CPU.

    Returns:
        The model, placed on the selected device, with weights loaded and
        switched to evaluation mode.
    """
    from collections import OrderedDict

    # define model
    device = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")
    model_name = os.path.basename(model_path)
    h_input, w_input, model_type, _ = parse_model_name(model_name)
    kernel_size = get_kernel(h_input, w_input)
    model = MODEL_MAPPING[model_type](conv6_kernel=kernel_size).to(device)

    # load model weight
    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should ever be placed in the model directory.
    state_dict = torch.load(model_path, map_location=device)

    # Some checkpoints carry a leading "module." on their parameter names
    # (e.g. when saved from a wrapped model). Strip it only where it really
    # is a prefix, instead of blindly cutting 7 characters from every key.
    prefix = "module."
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    model.load_state_dict(state_dict)
    model.eval()
    return model
# The video stream captured by the Android APK has a 3:4 width:height ratio,
# so images are restricted to 3:4 to stay consistent with it.
def check_image(image):
    """Report whether ``image`` has a width:height ratio of exactly 3:4.

    Args:
        image: Array whose ``shape`` starts with ``(height, width)``; both
            colour (3-D) and grayscale (2-D) images are accepted.

    Returns:
        ``True`` when ``width / height == 3 / 4``; otherwise prints a
        message and returns ``False``. An image with zero height is
        reported as not appropriate.
    """
    height, width = image.shape[:2]
    # Integer cross-multiplication avoids both float comparison and a
    # division by zero on an empty image.
    if height <= 0 or width * 4 != height * 3:
        print("Image is not appropriate!!!\nHeight/Width should be 4/3.")
        return False
    return True
# Face liveness (anti-spoofing) detection.
def anti_spoofing(image_name, model_dir, cpu_or_cuda, bbox, model_list):
    """Classify the face inside ``bbox`` of an image file as real or fake.

    Every entry of ``os.listdir(model_dir)`` is parsed with
    ``parse_model_name`` to get that model's crop parameters; the face is
    cropped accordingly and scored by the model at the same position in
    ``model_list``. The scores of all models are summed and averaged.

    Args:
        image_name: Path of the image file to analyse.
        model_dir: Directory holding the model files; only the file names
            are used here.
        cpu_or_cuda: ``"cuda"`` to run on the GPU, anything else for CPU.
        bbox: Face bounding box, forwarded unchanged to ``CropImage.crop``
            (presumably ``[x, y, width, height]`` -- confirm against caller).
        model_list: Loaded models, indexed in ``os.listdir(model_dir)`` order.

    Returns:
        Tuple ``(label, score)``: ``label`` is ``"real face"`` when the
        averaged score of class index 1 exceeds 0.915 and ``"fake face"``
        otherwise; ``score`` is that average formatted with 10 decimals.

    Raises:
        ValueError: If the file content cannot be decoded as an image.
    """
    model_test = AntiSpoofPredict(cpu_or_cuda)
    image_cropper = CropImage()

    image = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None; fail with a clear
        # message instead of an AttributeError on ``image.shape`` below.
        raise ValueError("Cannot decode image file: {}".format(image_name))

    h, w = image.shape[:2]
    factor = h / w
    if w > 1000:
        # Downscale wide images to a width of 600, keeping the aspect ratio.
        # NOTE(review): ``bbox`` is not rescaled here -- confirm the caller
        # supplies it in the coordinates of the resized image.
        image = cv2.resize(image, (600, int(600 * factor)))

    model_names = os.listdir(model_dir)
    prediction = np.zeros((1, 3))
    # sum the prediction from single model's result
    for index, model_name in enumerate(model_names):
        h_input, w_input, _, scale = parse_model_name(model_name)
        img = image_cropper.crop(
            org_img=image,
            bbox=bbox,
            scale=scale,
            out_w=w_input,
            out_h=h_input,
            # Cropping is disabled when the model name carries no scale.
            crop=scale is not None,
        )
        prediction += model_test.predict(img, model_list[index])

    # Average over the models actually evaluated instead of a hard-coded 2;
    # max() keeps an empty model directory from dividing by zero.
    value = prediction[0][1] / max(len(model_names), 1)
    verdict = "real face" if value > 0.915 else "fake face"
    return verdict, '{:.10f}'.format(value)
if __name__ == "__main__":
    # Command-line entry point: only parses the options below.
    cli = argparse.ArgumentParser(description="test")
    cli.add_argument("--device_id", type=int, default=0,
                     help="which gpu id, [0/1/2/3]")
    cli.add_argument("--model_dir", type=str,
                     default="./resources/anti_spoof_models",
                     help="model_lib used to test")
    cli.add_argument("--image_name", type=str, default="000_0.bmp",
                     help="image used to test")
    args = cli.parse_args()
    # NOTE: the call below is disabled; its argument list does not match the
    # current anti_spoofing signature.
    # anti_spoofing(args.image_name, args.model_dir, args.device_id)