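"""Flask service for cross-modal image-text retrieval with the SCAN model.

On startup the script loads a trained SCAN checkpoint, the vocabulary and the
pre-computed test-set data, then exposes two POST endpoints:

* /text2image  -- takes a query sentence in the form field "string" and returns
  the file names of the ten best-matching test images.
* /image2text  -- takes an uploaded image in the file field "file_name" and
  returns the ten best-matching test captions.

Responses are JSON objects with a "status" field and, on success, a "textList"
field holding the retrieved results.
"""
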
from flask import Flask, request, jsonify
import os
import time

import numpy as np
import torch
import nltk

from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

app = Flask(__name__)

ALLOWED_IMG = set(['png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG'])
# limit uploaded images to 10 MB
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = set(['zip'])

model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"

# load the model checkpoint and the options it was trained with
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)

# captions of the test split, one per line
caps_list = []
with open("test_caps.txt", "r") as f:
    for line in f.readlines():
        caps_list.append(line.strip("\n"))
print(len(caps_list))

# image file names, one per caption line; the part before '#' is the file name
image_list = []
with open("result.txt", "r") as f:
    for line in f.readlines():
        image_list.append(line.strip("\n").split("#")[0])

# test ids used to map retrieved images back to file names in image_list
id_list = []
with open("test_ids.txt", "r") as f:
    for line in f.readlines():
        id_list.append(line.strip("\n"))

if data_path is not None:
    opt.data_path = data_path

# load vocabulary used by the model
vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

# build the SCAN model and restore the trained weights
model = SCAN(word2idx, opt)
model = torch.nn.DataParallel(model)
model.cuda()

# load model state
model.load_state_dict(checkpoint['model'])

print('Loading dataset')
loadTime = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab,
                              opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - loadTime))

print('Computing results...')
# pre-compute embeddings for all test images and captions
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
print(img_embs.shape, cap_embs.shape)
# each image is repeated once per caption (5x); keep a single copy of each
img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])

# model used to extract features from uploaded images
featureModel, cfg = loadModel()

# check that an uploaded file has an allowed extension
def check_file_format(file_name, allowed_formats):
    if '.' in file_name:
        file_format = file_name.rsplit('.', 1)[1]
        if file_format in allowed_formats:
            return True
    return False

# check the image size and raise if it exceeds 10 MB
def check_img_size(img_path):
    fsize = os.path.getsize(img_path)
    if fsize > ALLOWED_IMG_SIZE:
        raise RequestEntityTooLarge

def create_response(status, textList=None):
    # res is the top-level JSON body returned to the client
    res = {}
    res['status'] = status
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)

@app.route('/')
def hello():
    return 'Hello, World!'

@app.route('/text2image', methods=['POST'])
def text2image():
    # query text sent as a form field named "string"
    query = request.form.get("string")
    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)
    # cross-attention similarity between the query and every test image
    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    sims = sims.T
    print(sims.shape)

    # rank images by similarity and return the 10 best matches;
    # each image has 5 captions in the index files, hence the factor of 5
    # when mapping back to file names
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    imgList = []
    for i in inds[:10]:
        print(image_list[5 * int(id_list[5 * i])])
        imgList.append(image_list[5 * int(id_list[5 * i])])
    return create_response(1, imgList)

@app.route('/image2text', methods=['POST'])
def image2text():
    try:
        # image uploaded as a multipart file field named "file_name"
        f = request.files['file_name']
        if f and check_file_format(f.filename, ALLOWED_IMG):
            img_path = './image/' + secure_filename(f.filename)
            f.save(img_path)
            check_img_size(img_path)
            tic = time.time()
            # extract features for the uploaded image and embed it
            image_feat = imgFeature(featureModel, cfg, img_path)
            print(image_feat.shape)
            img_emb, img_mean = encode_image(model, image_feat)
            print(img_emb.shape)
            # similarity between the uploaded image and every test caption
            sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=10240)
            print(sims.shape)
            inds = np.argsort(sims[0])[::-1]
            print(inds[:10])
            print(inds.dtype)
            # return the 10 most similar captions
            textList = []
            for i in inds[:10]:
                textList.append(caps_list[i])
                print(caps_list[i])
            print('time: {:.4f}'.format(time.time() - tic))
            return create_response(1, textList)
        else:
            return create_response('png jpg jpeg bmp are allowed')
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')

if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5001)
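
# Example client calls (a sketch, assuming the server runs locally on port 5001,
# the "requests" package is installed, and ./image/ride.jpg exists; the query
# text below is only illustrative):
#
#   import requests
#   # text-to-image retrieval: the query goes in the "string" form field
#   r = requests.post("http://localhost:5001/text2image",
#                     data={"string": "a man riding a horse"})
#   print(r.json())
#   # image-to-text retrieval: the image goes in the "file_name" file field
#   with open("./image/ride.jpg", "rb") as img:
#       r = requests.post("http://localhost:5001/image2text",
#                         files={"file_name": img})
#   print(r.json())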