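"""Flask web service for cross-modal retrieval with a trained SCAN model.

At startup it restores a SCAN checkpoint, precomputes embeddings for the test
gallery (images and captions), and then serves two endpoints: /text2image,
which returns the 10 best-matching gallery images for a query caption, and
/image2text, which returns the 10 best-matching captions for an uploaded image.
"""
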
from flask import Flask, request, jsonify
import os
import time
import numpy as np
import torch
import nltk
# load model and options
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
# from extract_features import feature
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

app = Flask(__name__)

ALLOWED_IMG = set(['png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG'])
# limit uploaded images to at most 10 MB
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = set(['zip'])

model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)

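# The gallery metadata is read from three text files for the test split. Assumed
# layout (matching how they are parsed below): test_caps.txt holds one caption per
# line, result.txt holds one "filename#k ..." line per caption (the part before '#'
# is the image filename), and test_ids.txt holds the image id of each caption,
# with 5 consecutive captions per image.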
caps_list = []
with open("test_caps.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        caps_list.append(line)
print(len(caps_list))

image_list = []
with open("result.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        image_list.append(line.split("#")[0])
# print(len(image_list))
# print(image_list[:10])

id_list = []
with open("test_ids.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        id_list.append(line)
# print(len(id_list))
# print(id_list[:10])

if data_path is not None:
    opt.data_path = data_path

# load vocabulary used by the model
vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

# build SCAN and wrap it for (multi-)GPU execution
model = SCAN(word2idx, opt)
model = torch.nn.DataParallel(model)
model.cuda()

# load model state
model.load_state_dict(checkpoint['model'])

print('Loading dataset')
loadTime = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab,
                              opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - loadTime))
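# Precompute embeddings for the whole test gallery once at startup, so each request
# only needs to encode the query and score it against these cached embeddings.
# loadModel()/imgFeature() (from extract_features) are used later to extract visual
# features from uploaded query images; they are assumed to produce features in the
# same format the SCAN model was trained on.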
print('Computing results...')
# img_embs, img_means, cap_embs, cap_lens, cap_means, img_id = encode_data(model, data_loader)
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
print(img_embs.shape, cap_embs.shape)
# each image is repeated 5 times (once per caption); keep one embedding per image
img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
featureModel, cfg = loadModel()

# check the format of an uploaded file against a set of allowed extensions
def check_file_format(file_name, allowed_formats):
    if '.' in file_name:
        file_format = file_name.rsplit('.', 1)[1]
        if file_format in allowed_formats:
            return True
    return False

# check the image size; raise an exception if it exceeds 10 MB
def check_img_size(img_path):
    fsize = os.path.getsize(img_path)
    if fsize > ALLOWED_IMG_SIZE:
        raise RequestEntityTooLarge

def create_response(status, textList=None):
    # res is the top-level JSON structure returned to the client
    res = {}
    res['status'] = status
    # res['message'] = message
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)

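# Illustrative response shape (values are examples only):
#   {"status": 1, "textList": ["a man riding a horse", "..."]}
# On errors the handlers below reuse the status field to carry the error message.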

@app.route('/')
def hello():
    return 'Hello, World!'

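# Example request for the text-to-image endpoint (assuming the service runs
# locally on port 5001 as configured below; the caption is illustrative):
#   curl -X POST -F "string=a man riding a horse" http://localhost:5001/text2image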
@app.route('/text2image', methods=['POST'])
def text2image():
    query = request.form.get("string")
    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)
    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    sims = sims.T
    print(sims.shape)

    # rank gallery images by similarity to the query caption and return the top 10
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    imgList = []
    for i in inds[:10]:
        # id_list/result.txt hold 5 rows per image, hence the factor of 5 when
        # mapping a gallery index back to an image filename
        print(image_list[5 * int(id_list[5 * i])])
        # print(image_list[5 * int(id_list[i])])
        imgList.append(image_list[5 * int(id_list[5 * i])])
    return create_response(1, imgList)

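# Example request for the image-to-text endpoint (the file field must be named
# 'file_name'; the path is illustrative):
#   curl -X POST -F "file_name=@./image/ride.jpg" http://localhost:5001/image2text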
@app.route('/image2text', methods=['POST'])
def image2text():
    try:
        f = request.files['file_name']
        if f and check_file_format(f.filename, ALLOWED_IMG):
            img_path = './image/' + secure_filename(f.filename)
            f.save(img_path)
            check_img_size(img_path)
            tic = time.time()
            # extract visual features for the uploaded image and embed it
            image_feat = imgFeature(featureModel, cfg, img_path)
            print(image_feat.shape)
            img_emb, img_mean = encode_image(model, image_feat)
            print(img_emb.shape)
            # score the uploaded image against all precomputed caption embeddings
            sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=10240)
            print(sims.shape)
            inds = np.argsort(sims[0])[::-1]
            # inds = inds.astype("int32")
            print(inds[:10])
            print(inds.dtype)
            textList = []
            for i in inds[:10]:
                textList.append(caps_list[i])
                print(caps_list[i])
            print('time: {:.4f}'.format(time.time() - tic))
            return create_response(1, textList)
        else:
            return create_response('png jpg jpeg bmp are allowed')
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5001)