# Graduation_Project/LHL/app.py
# Flask service for text<->image retrieval (last committed 2024-06-25 11:50:04 +08:00).
from flask import Flask, request, jsonify
import os
import time
import numpy as np
import torch
import nltk
# load model and options
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename
from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
#from extract_features import feature
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab
# Pin the process to physical GPU 2. This must run before any CUDA work
# (torch.load below may restore tensors straight onto the GPU).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
app = Flask(__name__)

# Image extensions accepted by /image2text. Both lower- and upper-case
# spellings are listed for every format, including 'BMP'.
ALLOWED_IMG = {'png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG', 'BMP'}
# Maximum accepted size of an uploaded image: 10 MB.
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = {'zip'}

model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"

# The checkpoint stores the trained weights under 'model' and the training
# options under 'opt'; the options are needed to rebuild the network.
# NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)
def _read_stripped_lines(path):
    """Return every line of the text file *path* with '\\n' characters stripped."""
    with open(path, "r") as handle:
        return [raw.strip("\n") for raw in handle.readlines()]


# Test captions, one per line. image2text indexes this list with positions
# ranked over cap_embs.
caps_list = _read_stripped_lines("test_caps.txt")
print(len(caps_list))
# Keep only the text before the first '#' on each line (presumably the image
# file name -- confirm against result.txt).
image_list = [entry.split("#")[0] for entry in _read_stripped_lines("result.txt")]
# Raw id strings; text2image converts them with int().
id_list = _read_stripped_lines("test_ids.txt")
# Point the options at the local data directory.
if data_path is not None:
    opt.data_path = data_path

# Load the vocabulary the model was trained with.
vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

# Rebuild the network, wrap it in DataParallel, move it to the GPU and
# restore the trained weights from the checkpoint.
model = SCAN(word2idx, opt)
model = torch.nn.DataParallel(model)
model.cuda()
model.load_state_dict(checkpoint['model'])

print('Loading dataset')
loadTime = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab,
                              opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - loadTime))

print('Computing results...')
# Embed the whole test split once at start-up so the request handlers only
# need to embed the query side.
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
print(img_embs.shape, cap_embs.shape)

# Keep every 5th row so there is one image embedding per image.
# The stride 5 assumes the loader repeats each image once per caption
# (5 captions per image) -- TODO confirm.
# text2image passes img_means together with img_embs to shard_xattn, so
# img_means is thinned the same way. It is only subsampled while it still
# has one row per caption.
# NOTE(review): confirm the pairing against shard_xattn.
n_caption_rows = len(img_embs)
img_embs = np.array(img_embs[::5])
if img_means is not None and len(img_means) == n_caption_rows:
    img_means = img_means[::5]

# Region-feature extractor used by /image2text for uploaded images.
featureModel, cfg = loadModel()
# Check the format (extension) of an uploaded file.
def check_file_format(file_name, format):
    """Return True if *file_name* ends with an extension listed in *format*.

    Only the text after the last '.' counts as the extension, so
    "holiday.photo.jpg" is accepted and "payload.jpg.php" is rejected.
    The comparison ignores case, so "IMG.BMP" matches 'bmp'.

    :param file_name: client-supplied file name (may be empty or None).
    :param format: collection of allowed extensions, without the leading dot.
    :return: True when the extension is allowed. False otherwise, including
        for names that contain no '.'.
    """
    if not file_name or '.' not in file_name:
        return False
    # maxsplit=1 isolates the *last* component. Without it, index [1] would
    # select the second component of a multi-dot name instead of the extension.
    extension = file_name.rsplit('.', 1)[1].lower()
    return extension in {allowed.lower() for allowed in format}
# Reject images whose size exceeds ALLOWED_IMG_SIZE.
def check_img_size(img_path):
    """Raise RequestEntityTooLarge if the file at *img_path* is larger than ALLOWED_IMG_SIZE bytes.

    Returns None when the file is within the limit.
    """
    size_in_bytes = os.path.getsize(img_path)
    if size_in_bytes <= ALLOWED_IMG_SIZE:
        return
    raise RequestEntityTooLarge
def create_response(status, textList=None, message=None):
    """Build the JSON response returned by the API endpoints.

    :param status: value stored under 'status' (the handlers use 1 for success).
    :param textList: optional list of results. It is included under
        'textList' only when it is not None.
    :param message: optional human-readable message. It is included under
        'message' only when it is not None.
    :return: a Flask JSON response built with ``jsonify``.
    """
    # res is the top-level JSON object.
    res = {'status': status}
    if message is not None:
        res['message'] = message
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)
@app.route('/')
def hello():
    """Return a fixed plain-text greeting from the root URL."""
    greeting = 'Hello, World!'
    return greeting
@app.route('/text2image', methods=['POST'])
def text2image():
    """Return the names of the 10 test images that best match a text query.

    The query is read from the form field ``string``. It is embedded with
    ``encode_cap`` and scored against the pre-computed image embeddings with
    ``shard_xattn``.

    On success the response is ``{'status': 1, 'textList': [...]}``. A
    missing or blank query returns ``{'status': 0, 'message': ...}`` instead
    of failing inside the encoder.
    """
    query = request.form.get("string")
    if not (query and query.strip()):
        return jsonify({'status': 0, 'message': 'form field "string" is required'})
    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)
    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    # The result is indexed (image, caption), as in image2text. Transpose it
    # so that row 0 holds the query's score against every image.
    sims = sims.T
    print(sims.shape)
    # Image rows ordered from most to least similar.
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    imgList = []
    for i in inds[:10]:
        # Row i of img_embs comes from sample 5*i after the stride-5
        # subsampling at start-up. id_list maps that sample to an image
        # index; result.txt presumably holds 5 lines per image.
        name = image_list[5 * int(id_list[5 * i])]
        print(name)
        imgList.append(name)
    return create_response(1, imgList)
@app.route('/image2text', methods=['POST'])
def image2text():
    """Return the 10 test captions that best describe an uploaded image.

    Processing steps:
    - The image is read from the multipart field ``file_name`` and saved
      under ./image/.
    - The upload is validated against ALLOWED_IMG and ALLOWED_IMG_SIZE.
    - Region features are extracted with ``imgFeature`` and embedded with
      ``encode_image``.
    - The embedding is scored against the pre-computed caption embeddings.

    On success the response is ``{'status': 1, 'textList': [...]}``. A
    rejected upload returns ``{'status': 0, 'message': ...}``; previously the
    message text was placed in 'status'.
    """
    try:
        upload = request.files['file_name']
        if not (upload and check_file_format(upload.filename, ALLOWED_IMG)):
            return jsonify({'status': 0, 'message': 'png jpg jpeg bmp are allowed'})
        img_path = os.path.join('./image', secure_filename(upload.filename))
        upload.save(img_path)
        try:
            check_img_size(img_path)
        except RequestEntityTooLarge:
            # The file is written before its size can be checked. Remove it
            # so rejected uploads do not accumulate on disk.
            os.remove(img_path)
            raise
    except RequestEntityTooLarge:
        return jsonify({'status': 0, 'message': 'image size should be less than 10M'})

    tic = time.time()
    image_feat = imgFeature(featureModel, cfg, img_path)
    print(image_feat.shape)
    img_emb, img_mean = encode_image(model, image_feat)
    print(img_emb.shape)
    # Score the single uploaded image against every pre-computed caption.
    # Row 0 of sims holds its score for each caption.
    sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=10240)
    print(sims.shape)
    # Caption indices ordered from most to least similar.
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    textList = [caps_list[i] for i in inds[:10]]
    for caption in textList:
        print(caption)
    print('time: {:.4f}'.format(time.time() - tic))
    return create_response(1, textList)
if __name__ == '__main__':
    # Bind to all network interfaces so other machines can reach the API.
    app.run(port=5001, host="0.0.0.0")