Graduation_Project/LHL/app.py

161 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from flask import Flask, request, jsonify
import os
import time
import numpy as np
import torch
import nltk
# load model and options
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename
from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
#from extract_features import feature
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab
# Expose only GPU 2 to this process (takes effect only if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
app = Flask(__name__)
# Image extensions accepted by the upload endpoint (lower- and upper-case variants).
ALLOWED_IMG = {'png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG', 'BMP'}
# Uploaded images may be at most 10 MiB.
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = {'zip'}
model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"
# NOTE(review): torch.load unpickles the file, so the checkpoint must come from a trusted
# source. 'opt' looks like a pickled options object; on torch versions where weights_only
# defaults to True this load may need weights_only=False — confirm against the torch in use.
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)
def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path* with line endings removed."""
    with open(path, "r", encoding="utf-8") as fh:
        # rstrip("\r\n") also cleans files written with Windows line endings.
        return [line.rstrip("\r\n") for line in fh]


# Test-split captions, one per line; /image2text returns entries of this list.
caps_list = _read_lines("test_caps.txt")
print(len(caps_list))
# Image names: only the part of each result.txt line before the first '#' is kept.
image_list = [line.split("#")[0] for line in _read_lines("result.txt")]
# Test-split ids, one per line, kept as strings (text2image converts them with int()).
id_list = _read_lines("test_ids.txt")
# Point the restored options at the local data directory.
if data_path is not None:
    opt.data_path = data_path

# Vocabulary the model was trained with: <vocab_path>/<data_name>_vocab.json.
vocab_file = os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)
vocab = deserialize_vocab(vocab_file)
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

# Build SCAN, wrap it in DataParallel, move it to the GPU and restore the trained weights.
model = torch.nn.DataParallel(SCAN(word2idx, opt))
model.cuda()
model.load_state_dict(checkpoint['model'])
# NOTE(review): model.eval() is never called here; presumably the encode_* helpers switch
# the model to inference mode themselves — confirm.

print('Loading dataset')
load_start = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab, opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - load_start))

print('Computing results...')
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
print(img_embs.shape, cap_embs.shape)
# Keep every 5th image row so each image appears once (presumably 5 captions per image).
# NOTE(review): img_means is not subsampled the same way, yet text2image passes it to
# shard_xattn together with img_embs — confirm the two stay aligned.
img_embs = np.array(img_embs[::5])

# Feature extractor and its config, used by /image2text through imgFeature().
featureModel, cfg = loadModel()
def check_file_format(file_name, format):
    """Return True if the extension of *file_name* is one of *format*.

    The extension is the text after the LAST '.', compared case-insensitively.
    The previous rsplit('.')[1] took the second dot-separated piece instead, which
    rejected names like "a.b.jpg" and accepted names like "evil.jpg.exe".

    Args:
        file_name: client-supplied file name; may be None or empty.
        format: collection of allowed extensions, without the leading dot.

    Returns:
        False when the name is empty, has no '.', or its extension is not allowed.
    """
    if not file_name or '.' not in file_name:
        return False
    file_format = file_name.rsplit('.', 1)[1].lower()
    return file_format in {ext.lower() for ext in format}
def check_img_size(img_path):
    """Raise RequestEntityTooLarge if the file at *img_path* exceeds ALLOWED_IMG_SIZE bytes.

    Returns None when the file is within the limit.
    """
    if os.path.getsize(img_path) <= ALLOWED_IMG_SIZE:
        return
    raise RequestEntityTooLarge
def create_response(status, textList=None):
    """Build the JSON response returned by the endpoints.

    Args:
        status: stored under 'status'. Success paths pass 1; the error paths in
            this file pass a human-readable error message here instead.
        textList: optional list of result strings, included under 'textList'
            only when it is not None.

    Returns:
        A Flask JSON response built from {'status': ..., ['textList': ...]}.
    """
    res = {'status': status}
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)
@app.route('/')
def hello():
    """Root endpoint; always answers with the same plain-text greeting."""
    greeting = 'Hello, World!'
    return greeting
@app.route('/text2image', methods=['POST'])
def text2image():
    """Return the names of the 10 images that best match the caption in form field "string".

    The query is encoded with encode_cap and scored against the pre-computed image
    embeddings with shard_xattn. The highest-scoring images are mapped back to file
    names through id_list / image_list.

    Returns:
        create_response(1, [image names]) on success, or an error response when the
        "string" form field is missing or blank.
    """
    query = request.form.get("string")
    # Without this guard a missing field would reach encode_cap as None.
    if query is None or not query.strip():
        return create_response('string is required')
    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)
    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    # Transpose so that row 0 holds the query's score for every image
    # (assumes shard_xattn returns (n_images, n_captions) — confirm).
    sims = sims.T
    print(sims.shape)
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    imgList = []
    for i in inds[:10]:
        # img_embs keeps every 5th row of the loader output, so image i is row 5*i there.
        img_name = image_list[5 * int(id_list[5 * i])]
        print(img_name)
        imgList.append(img_name)
    return create_response(1, imgList)
@app.route('/image2text', methods=['POST'])
def image2text():
    """Return the 10 captions that best match an uploaded image.

    The image is read from the multipart field "file_name" and saved under ./image/.
    It is then size-checked, converted to features with imgFeature, encoded with
    encode_image, and scored against the pre-computed caption embeddings.

    Returns:
        create_response(1, [captions]) on success. An error response is returned when
        the extension is not allowed or the image exceeds ALLOWED_IMG_SIZE.
    """
    img_path = None
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')
        img_path = os.path.join('./image', secure_filename(f.filename))
        f.save(img_path)
        check_img_size(img_path)
        tic = time.time()
        image_feat = imgFeature(featureModel, cfg, img_path)
        print(image_feat.shape)
        img_emb, img_mean = encode_image(model, image_feat)
        print(img_emb.shape)
        sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=10240)
        print(sims.shape)
        # Row 0 holds the uploaded image's score for every caption; rank best first.
        inds = np.argsort(sims[0])[::-1]
        print(inds[:10])
        textList = [caps_list[i] for i in inds[:10]]
        for text in textList:
            print(text)
        print('time: {:.4f}'.format(time.time() - tic))
        return create_response(1, textList)
    except RequestEntityTooLarge:
        # The file was written before its size could be checked. Delete the rejected
        # upload so oversized files do not pile up in ./image/.
        if img_path is not None:
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup; the error response is still returned
        return create_response('image size should be less than 10M')
if __name__ == '__main__':
    # Bind on all interfaces, port 5001, so clients on other hosts can reach the service.
    app.run(port=5001, host="0.0.0.0")