450 lines
17 KiB
Python
450 lines
17 KiB
Python
import os
import re
import shutil
import tempfile
import time
import zipfile

import faiss
from flask import Flask, render_template, request, jsonify, send_from_directory
from markupsafe import escape, escape_silent
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename
import numpy as np
import torch

from anti import anti_spoofing, load_anti_model
from face_api import load_arcface_model, load_npy, findOne, load_image, face_verification, findAll, \
    add_one_to_database, get_claster_tmp_file_embedding, cluster, detect_video
from gender_age import set_gender_conf, gender_age, load_gender_model
from retinaface_detect import load_retinaface_model, detect_one, set_retinaface_conf

# Allowed image extensions. Upper-case variants are listed explicitly so that matching does
# not depend on how the extension check treats case ('BMP' was previously missing, so
# upper-case .BMP uploads were rejected while .PNG/.JPG/.JPEG were accepted).
ALLOWED_IMG = {'png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG', 'BMP'}
# Maximum size of a single uploaded image: 10 MB.
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = {'zip'}
ALLOWED_VIDEO = {'mp4'}

app = Flask(__name__)

# Maximum size of a whole request body: 100 MB (larger requests raise RequestEntityTooLarge).
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024

# Make jsonify emit non-ASCII (e.g. Chinese) text as-is instead of \uXXXX escapes.
# Flask < 2.3 reads the JSON_AS_ASCII config key; Flask >= 2.2 exposes app.json.ensure_ascii,
# and from 2.3 on only that attribute is honoured.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
    app.json.ensure_ascii = False

# Run the models on the GPU ("cuda") when available, otherwise on the CPU.
cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"

# Face recognition (embedding) model.
arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)

# Face detection model.
retinaface_args = set_retinaface_conf(cpu_or_cuda=cpu_or_cuda)
retinaface_model = load_retinaface_model(retinaface_args)

# Gender / age model.
gender_args = set_gender_conf()
gender_model = load_gender_model(gender_args, 'fc1')

# Liveness (anti-spoofing) model.
anti_spoofing_model_path = "model/anti_spoof_models"
anti_model = load_anti_model(anti_spoofing_model_path, cpu_or_cuda)

# The face database (k_v, database_name_list) and the faiss ``index`` used by the endpoints
# are loaded in the ``if __name__ == '__main__':`` block at the bottom of this file.


@app.route('/')
def index():
    """Root endpoint: reply with the fixed plain-text string "model".

    Note: when this file is run as a script, the module-level name ``index`` is later
    rebound to the faiss index in the ``__main__`` block; Flask keeps its own reference
    to this view, so the route keeps working.
    """
    reply = "model"
    return reply


@app.route('/hello')
@app.route('/hello/<name>')
def hello(name=None):
    """Render the ``hello.html`` template.

    *name* is taken from the URL path (``/hello/<name>``) and is None for a bare ``/hello``.
    """
    context = {'name': name}
    return render_template('hello.html', **context)


@app.route('/user', methods=['GET'])
def show_user_name():
    """Echo the ``username`` query parameter (empty string when it is absent).

    The value is HTML-escaped: a plain ``str`` returned from a Flask view is served as
    text/html, so echoing it raw allowed reflected XSS (e.g. ``?username=<script>...``).
    """
    return escape(request.args.get('username', ''))


# 创建返回的json数据
|
||
# 函数参数用是否=None判断,函数中定义的data,result用true,false判断
|
||
def create_response(status, name=None, distance=None, verification=None, gender=None, age=None, num=None, anti=None,
|
||
score=None, box_and_point=None, addfile_names=None,fail_names=None,database_name=None,msg=None,
|
||
delete_names=None,not_exist_names=None):
|
||
# res为总的json结构体
|
||
res = {}
|
||
res['status'] = status
|
||
|
||
data = {}
|
||
try:
|
||
data["box_and_point"] = box_and_point.tolist()
|
||
except AttributeError:
|
||
pass
|
||
if anti != None and score != None:
|
||
liveness = {}
|
||
liveness["spoofing"] = anti
|
||
liveness['score'] = score
|
||
data['liveness'] = liveness
|
||
if distance!=None:
|
||
data['distance'] = float(distance)
|
||
if verification!=None:
|
||
data['verification'] = verification
|
||
if num!=None:
|
||
data['number'] = num
|
||
if gender!=None:
|
||
data['gender'] = gender
|
||
if age!=None:
|
||
data['age'] = age
|
||
if name!=None:
|
||
data['name'] = name
|
||
if data:
|
||
res['data'] = data
|
||
|
||
# 数据库增删接口返回数据
|
||
result = {}
|
||
if msg!=None:
|
||
res['msg'] = msg
|
||
if database_name!=None:
|
||
result['database_name'] = database_name
|
||
# 增加人脸
|
||
if addfile_names!=None or fail_names!=None:
|
||
result['success_names'] = addfile_names
|
||
result['fail_names'] = fail_names
|
||
# 删除人脸
|
||
if delete_names!=None or not_exist_names!=None:
|
||
result['delete_names'] = delete_names
|
||
result['not_exist_names'] = not_exist_names
|
||
if result:
|
||
res['result'] = result
|
||
|
||
return jsonify(res)
|
||
|
||
|
||
# Build the JSON payload returned by the /cluster endpoint.
def create_cluster_response(status, all_cluster):
    """Return ``{'data': {'cluster0': ..., 'cluster1': ...}, 'status': status}``.

    Each element of *all_cluster* is stored unchanged under the key ``'cluster<i>'``,
    where ``i`` is its position in *all_cluster*. The plain dict is returned (not jsonify'd).
    """
    clusters = {'cluster{}'.format(i): members for i, members in enumerate(all_cluster)}
    return {'data': clusters, 'status': status}


# Check the extension of an uploaded file.
def check_file_format(file_name, format):
    """Return True if *file_name* has an extension contained in *format*.

    Only the text after the LAST '.' is the extension, so 'photo.v2.jpg' is accepted and
    'evil.jpg.php' is rejected (the old ``rsplit('.')[1]`` picked the second dot-separated
    segment instead). The extension is also tried lower-cased, so 'IMG.BMP' matches a set
    containing 'bmp'.

    Args:
        file_name: client-supplied file name; may be None or empty.
        format: collection of allowed extensions, without the leading dot.
    """
    if not file_name or '.' not in file_name:
        return False
    extension = file_name.rsplit('.', 1)[1]
    return extension in format or extension.lower() in format


# Reject images larger than the 10 MB limit.
def check_img_size(img_path):
    """Raise RequestEntityTooLarge if the file at *img_path* is larger than ALLOWED_IMG_SIZE bytes."""
    too_big = os.path.getsize(img_path) > ALLOWED_IMG_SIZE
    if too_big:
        raise RequestEntityTooLarge


# Extract a zip archive into a directory.
def unzip(zip_src, dst_dir):
    """Extract every member of *zip_src* into *dst_dir*.

    Args:
        zip_src: path or file-like object (e.g. an uploaded FileStorage) holding the archive.
        dst_dir: destination directory.

    Returns:
        True if *zip_src* is a zip archive and was extracted, False otherwise.

    The archive is opened in a ``with`` block so it is closed even if extraction fails
    (the previous version never closed it). ``ZipFile`` extraction strips absolute paths
    and '..' components from member names.
    """
    if not zipfile.is_zipfile(zip_src):
        return False
    with zipfile.ZipFile(zip_src, 'r') as archive:
        archive.extractall(dst_dir)
    return True


# Extract a zip file into a directory, creating the directory if needed.
def un_zip(file_path, output_path):
    """Extract every member of the zip at *file_path* into *output_path*.

    *output_path* (and any missing parents) is created if it does not exist. The archive is
    opened first, so an invalid file raises ``zipfile.BadZipFile`` before any directory is
    created. The archive is always closed, even when extraction fails.
    """
    with zipfile.ZipFile(file_path) as zip_file:
        os.makedirs(output_path, exist_ok=True)
        zip_file.extractall(output_path)


# Face recognition endpoint (gender/age estimation is currently disabled).
@app.route('/recognition', methods=['POST'])
def recognition():
    """Identify the face(s) in an uploaded image (form field ``file_name``).

    Responses (via create_response):
        * unsupported extension -> status 'png jpg jpeg bmp are allowed'
        * no face detected      -> status 'no face'
        * several faces         -> names from findAll and the detected boxes
        * exactly one face      -> name/distance from findOne, a liveness label/score from
                                   anti_spoofing, and the detected box
        * image over 10 MB      -> status 'image size should be less than 10M'
    ``gender`` and ``age`` are always empty lists while gender/age estimation is disabled.

    NOTE(review): ``index`` and ``database_name_list`` are module globals that are only
    assigned in the ``__main__`` block; under another server (e.g. a WSGI runner) ``index``
    still names the index() view function.
    """
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img_path = './img/recognition/' + secure_filename(f.filename)
        f.save(img_path)
        check_img_size(img_path)

        tic = time.time()
        img3, box_and_point = detect_one(img_path, retinaface_model, retinaface_args)
        print('detect time: {:.4f}'.format(time.time() - tic))

        # Gender/age estimation is disabled; the original call was
        # gender_list, age_list = gender_age(img3, gender_model)
        gender_list, age_list = [], []

        if len(img3) == 0:
            return create_response('no face')

        if len(img3) > 1:
            namelist = findAll(img3, arcface_model, index, database_name_list, cpu_or_cuda)
            return create_response('success', namelist, gender=gender_list, age=age_list,
                                   box_and_point=box_and_point)

        # Single face. anti_spoofing takes an (x, y, w, h) box, built here from the first
        # four values (x1, y1, x2, y2). It is built as a separate array: the previous code
        # overwrote box_and_point[0] in place, so this branch returned (x, y, w, h) to the
        # client while the multi-face branch returned (x1, y1, x2, y2).
        x1, y1, x2, y2 = box_and_point[0][:4]
        anti_box = np.array([x1, y1, x2 - x1, y2 - y1], int)
        label, value = anti_spoofing(img_path, anti_spoofing_model_path, cpu_or_cuda, anti_box, anti_model)

        name, distance = findOne(img3, arcface_model, index, database_name_list, cpu_or_cuda)
        return create_response('success', name, gender=gender_list, age=age_list, distance=distance,
                               anti=label, score=value, box_and_point=box_and_point)
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


# Compare the faces in two images.
@app.route('/compare', methods=['POST'])
def compare_file():
    """Verify whether two uploaded images (``file1_name``, ``file2_name``) show the same person.

    Each image must contain exactly one detected face. On success, the response carries the
    ``verification`` result and ``distance`` from face_verification.

    Each saved copy gets a distinct prefix. Previously, two uploads with the same file name
    were written to the same path, so the second overwrote the first and the image was
    compared with itself.
    """
    try:
        file1 = request.files['file1_name']
        file2 = request.files['file2_name']
        if not (file1 and check_file_format(file1.filename, ALLOWED_IMG)
                and file2 and check_file_format(file2.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img1_path = './img/compare/1_' + secure_filename(file1.filename)
        img2_path = './img/compare/2_' + secure_filename(file2.filename)
        file1.save(img1_path)
        file2.save(img2_path)
        check_img_size(img1_path)
        check_img_size(img2_path)

        img1, box_and_point1 = detect_one(img1_path, retinaface_model, retinaface_args)
        img2, box_and_point2 = detect_one(img2_path, retinaface_model, retinaface_args)
        if len(img1) != 1 or len(img2) != 1:
            return create_response('image contains no face or more than 1 face')

        result, distance = face_verification(img1, img2, arcface_model, cpu_or_cuda)
        print(result, distance)
        return create_response('success', verification=result, distance=distance)
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


# Add one or more faces to an existing face database (create / update entries).
@app.route('/databaseAdd', methods=['POST'])
def DB_add_face():
    """Add uploaded face images to an existing ``./Database/<database_name>.npy`` database.

    Form fields:
        file_list: one or more image files (repeat the key to upload several). The person's
            name is the first run of CJK characters in each file name.
        database_name: database file name without the ``.npy`` suffix.

    Response: status 1 when every file was added, 0 otherwise. ``result.fail_names`` holds
    ``formatWrong`` (bad extension, or no CJK name in the file name) and ``alreadyExist``
    (the name is already in the database or earlier in the same batch).

    Changes from the previous version:
        * each upload is saved under a generated name, never the client-supplied path;
        * a missing or path-like database_name is rejected;
        * the per-request temp dir is always removed and is not shared with other requests.

    NOTE(review): the faiss ``index`` / ``database_name_list`` used by /recognition are built
    once at startup and are not refreshed here.
    """
    try:
        # Uploaded face images (>= 1); every file uses the same key, file_list.
        upload_files = request.files.getlist("file_list")
        if not upload_files:
            msg = "上传文件为空"
            return create_response(0, msg=msg)

        database_name = request.form.get("database_name")
        # A missing name would crash the string concatenation below. A name containing a path
        # component could point outside ./Database/. Both are reported as "does not exist".
        if not database_name or os.path.basename(database_name) != database_name:
            msg = "数据库不存在"
            return create_response(0, msg=msg)
        database_path = "./Database/" + database_name + ".npy"
        if not os.path.exists(database_path):
            msg = "数据库不存在"
            return create_response(0, msg=msg)

        # Names already in the database; extended as faces are added, so duplicates inside
        # one batch are reported instead of silently overwriting each other.
        names = set(load_npy(database_path).keys())

        # Extracts the (Chinese) person name from the file name; used as the .npy key.
        name_pattern = re.compile('[\u4e00-\u9fa5]+')

        success_names = []
        format_wrong = []
        already_exist = []

        file_temp_root = './img/uploadNew/'
        os.makedirs(file_temp_root, exist_ok=True)
        # Private scratch directory for this request only.
        file_temp_path = tempfile.mkdtemp(dir=file_temp_root)
        try:
            for position, upload in enumerate(upload_files):
                filename = upload.filename or ''
                found = name_pattern.findall(filename)
                if not (upload and found and check_file_format(filename, ALLOWED_IMG)):
                    format_wrong.append(found[0] if found else filename)
                    continue
                name = found[0]
                if name in names:
                    already_exist.append(name)
                    continue

                # Use a generated name inside the scratch dir; keep only a sanitised extension.
                extension = secure_filename(filename.rsplit('.', 1)[-1])
                save_path = os.path.join(file_temp_path, '{}.{}'.format(position, extension))
                upload.save(save_path)
                check_img_size(save_path)

                img_file, box_and_point = detect_one(save_path, retinaface_model, retinaface_args)
                add_one_to_database(img=img_file, model=arcface_model, name=name, database_path=database_path,
                                    cpu_or_cuda=cpu_or_cuda)
                names.add(name)
                success_names.append(name)
        finally:
            # Clean up even when check_img_size or detection raises.
            shutil.rmtree(file_temp_path, ignore_errors=True)

        status = 0 if format_wrong or already_exist else 1
        fail_names = {'formatWrong': format_wrong, 'alreadyExist': already_exist}
        return create_response(status=status, addfile_names=success_names, fail_names=fail_names,
                               database_name=database_name, msg="新增人脸操作执行完成")
    except RequestEntityTooLarge:
        return create_response(0, msg='image size should be less than 10M')


# Delete one or more faces from an existing face database.
@app.route('/databaseDelete', methods=['POST'])
def DB_delete_face():
    """Remove the names in ``delete_names`` (repeatable form field) from ``./Database/<database_name>.npy``.

    Response: status 1 when every name was found and deleted, 0 otherwise.
    ``result.delete_names`` lists the removed names and ``result.not_exist_names`` lists the
    names that were not in the database. The file is rewritten only if something was deleted.
    """
    try:
        delete_names = request.form.getlist("delete_names")
        database_name = request.form.get("database_name")
        # Reject missing names and names with a path component so the path stays in ./Database/.
        if not database_name or os.path.basename(database_name) != database_name:
            msg = "数据库不存在"
            return create_response(0, msg=msg)
        database_path = "./Database/" + database_name + ".npy"
        if not os.path.exists(database_path):
            msg = "数据库不存在"
            return create_response(0, msg=msg)
        if not delete_names:
            msg = "delete_names参数为空"
            return create_response(0, msg=msg)

        k_v = load_npy(database_path)
        success_list = []
        fail_list = []
        for name in delete_names:
            if name in k_v:
                del k_v[name]
                success_list.append(name)
            else:
                fail_list.append(name)

        if success_list:
            np.save(database_path, k_v)

        status = 0 if fail_list else 1
        return create_response(status=status, delete_names=success_list, not_exist_names=fail_list,
                               database_name=database_name, msg="删除人脸操作完成")
    except RequestEntityTooLarge:
        # The message previously went into the positional ``name`` slot (data.name).
        # It belongs in ``msg``, as in the other database endpoint.
        return create_response(0, msg='file size should be less than 100M')


# Image-search endpoints.
# Upload a zip of images to build the search gallery.
@app.route('/uploadZip', methods=['POST'])
def upload_Zip():
    """Extract the uploaded archive (form field ``zip_name``) into ``./img/search/``.

    NOTE(review): /imgSearchImg searches ``./img/search/face``, so the archive is presumably
    expected to contain a top-level ``face/`` folder - confirm with the clients.
    """
    try:
        archive = request.files['zip_name']
        dst_dir = './img/search/'
        try:
            extracted = unzip(archive, dst_dir)
        except zipfile.BadZipFile:
            # Looks like a zip but is corrupt.
            extracted = False
        if extracted:
            return create_response('upload zip success')
        return create_response('upload zip file please')
    except RequestEntityTooLarge:
        # Zip uploads are bounded by MAX_CONTENT_LENGTH (100 MB), not the 10 MB image limit.
        return create_response('file size should be less than 100M')


# Search the gallery for images of the same person as the uploaded image.
@app.route('/imgSearchImg', methods=['POST'])
def img_search_img():
    """Return the gallery file names (``./img/search/face``) that face_verification reports as the same person.

    The query image (form field ``img_name``) must contain exactly one face. Gallery entries
    that are not image files, or in which no face is detected, are skipped.
    Response data: ``name`` = list of matching file names, ``number`` = how many matched.
    """
    searchfile = './img/search/face'
    try:
        file = request.files['img_name']
        if not (file and check_file_format(file.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img_path = './img/search/' + secure_filename(file.filename)
        file.save(img_path)
        check_img_size(img_path)

        img, _ = detect_one(img_path, retinaface_model, retinaface_args)
        if len(img) != 1:
            return create_response('image contains no face or more than 1 face')

        matched = []
        for filename in os.listdir(searchfile):
            imgpath = os.path.join(searchfile, filename)
            # Skip sub-directories and non-image files (e.g. __MACOSX/ or .DS_Store from zip tools).
            if not (os.path.isfile(imgpath) and check_file_format(filename, ALLOWED_IMG)):
                continue
            imgdata, _ = detect_one(imgpath, retinaface_model, retinaface_args)
            if len(imgdata) == 0:
                continue
            # face_verification returns (verdict, distance), as unpacked in compare_file. The old
            # code called .split() on its result, which fails on a tuple.
            # NOTE(review): the verdict may be the string 'same' (what the old check expected)
            # or a boolean - confirm against face_api.face_verification.
            verdict, _distance = face_verification(img, imgdata, arcface_model, cpu_or_cuda)
            if verdict in (True, 'same'):
                matched.append(filename)

        return create_response('success', name=matched, num=len(matched))
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


# Face clustering endpoint.
@app.route('/cluster', methods=['POST'])
def zip_cluster():
    """Cluster the faces in an uploaded zip (form field ``file_name``).

    The archive is saved and extracted into ``./img/cluster_tmp_file/``. The images are then
    read from the folder named after the zip without its extension, so the archive is
    expected to contain such a top-level folder. Returns the create_cluster_response payload.
    """
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_FILE)):
            return create_response('zip are allowed')

        tmp_dir = './img/cluster_tmp_file/'
        zip_name = secure_filename(f.filename)
        # splitext keeps dots inside the name ('day.1.zip' -> 'day.1'); the old rsplit('.')[0] gave 'day'.
        stem = os.path.splitext(zip_name)[0]
        extract_dir = tmp_dir + stem
        if stem:
            # Remove images left by an earlier upload of an archive with the same name, so
            # they are not clustered together with the new ones.
            shutil.rmtree(extract_dir, ignore_errors=True)

        zip_path = tmp_dir + zip_name
        f.save(zip_path)
        try:
            un_zip(zip_path, tmp_dir)
        except zipfile.BadZipFile:
            return create_response('zip are allowed')

        emb_list, name_list = get_claster_tmp_file_embedding(extract_dir, retinaface_model, retinaface_args,
                                                             arcface_model, cpu_or_cuda)
        return create_cluster_response("success", cluster(emb_list, name_list))
    except RequestEntityTooLarge:
        return create_response('file size should be less than 100M')


# Video recognition endpoint.
@app.route('/videorecognition', methods=['POST'])
def video_recognition():
    """Run detect_video on an uploaded .mp4 (form field ``file_name``).

    The upload is stored as ``./video/<name>``, and ``./videoout/<name>`` is passed to
    detect_video as the output path (the directory served by /download/<filename>).
    NOTE(review): relies on the module-level ``k_v``, which is only assigned in the
    ``__main__`` block.
    """
    try:
        upload = request.files['file_name']
        if not upload or not check_file_format(upload.filename, ALLOWED_VIDEO):
            return create_response('mp4 are allowed')

        video_name = secure_filename(upload.filename)
        source_path = './video/' + video_name
        output_path = './videoout/' + video_name
        upload.save(source_path)
        detect_video(source_path, output_path, retinaface_model, arcface_model, k_v, retinaface_args)
        return create_response("success")
    except RequestEntityTooLarge:
        return create_response('file size should be less than 100M')


@app.route('/download/<string:filename>', methods=['GET'])
def download(filename):
    """Send ``./videoout/<filename>`` as an attachment.

    If that file does not exist, a create_response payload with status "Download failed" is returned instead.
    """
    output_dir = './videoout/'
    if not os.path.isfile(os.path.join(output_dir, filename)):
        return create_response("Download failed")
    return send_from_directory(output_dir, filename, as_attachment=True)


if __name__ == '__main__':
    # Load the face database: a mapping of person name -> embedding, as stored by load_npy.
    k_v = load_npy("./Database/student.npy")
    # Row i of the faiss index corresponds to database_name_list[i].
    database_name_list = list(k_v.keys())
    # faiss works on float32 matrices; convert explicitly in case embeddings were saved as float64.
    vector_list = np.array(list(k_v.values())).astype('float32')
    print(vector_list.shape)

    embedding_dim = 512  # embedding size the index is built for (as in the original code)
    # IVF training needs at least as many vectors as clusters (faiss raises otherwise), so
    # cap nlist at the database size instead of always using 50.
    nlist = min(50, len(vector_list))
    if nlist > 0:
        quantizer = faiss.IndexFlatL2(embedding_dim)
        index = faiss.IndexIVFFlat(quantizer, embedding_dim, nlist, faiss.METRIC_L2)
        index.train(vector_list)
        index.add(vector_list)
        # Probe every inverted list, i.e. an exhaustive search (same as nprobe = nlist = 50 before).
        index.nprobe = nlist
    else:
        # Empty database: nothing to train on; a flat index needs no training.
        index = faiss.IndexFlatL2(embedding_dim)

    # NOTE(review): this rebinds the module-level name ``index`` (also the name of the index()
    # view, which Flask has already registered). /recognition reads this faiss index. It is
    # not rebuilt after /databaseAdd or /databaseDelete.
    app.run(host="0.0.0.0", port=5000)