"""Flask service exposing text-to-image and image-to-text retrieval with SCAN.

At import time the module loads the trained checkpoint, the vocabulary and the
test split, pre-computes the image/caption embeddings, and loads the image
feature extractor.  The two POST endpoints then rank the pre-computed
embeddings against the posted caption or uploaded image.
"""
import os
import time

import nltk
import numpy as np
import torch
from flask import Flask, request, jsonify
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

app = Flask(__name__)

ALLOWED_IMG = set(['png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG'])
# Uploaded images are limited to at most 10M.
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = set(['zip'])

model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"


def _read_lines(path):
    """Return the lines of the text file at *path* without trailing newlines."""
    with open(path, "r") as f:
        return [line.strip("\n") for line in f]


# Load model and options.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)

caps_list = _read_lines("test_caps.txt")
print(len(caps_list))

# Only the part of each line before the first '#' is kept.
image_list = [line.split("#")[0] for line in _read_lines("result.txt")]

id_list = _read_lines("test_ids.txt")

if data_path is not None:
    opt.data_path = data_path

# Load the vocabulary used by the model.
vocab = deserialize_vocab(
    os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

model = SCAN(word2idx, opt)
model = torch.nn.DataParallel(model)
model.cuda()

# Load model state.
model.load_state_dict(checkpoint['model'])

print('Loading dataset')
loadTime = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab,
                              opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - loadTime))

print('Computing results...')
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(
    model, data_loader)
print(img_embs.shape, cap_embs.shape)
# Keep every 5th image embedding (presumably the loader repeats each image
# once per caption, 5 captions per image -- TODO confirm).
img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])

featureModel, cfg = loadModel()


def check_file_format(file_name, format):
    """Return True if the extension of *file_name* is in *format*.

    The extension is the text after the LAST dot, so ``"a.b.jpg"`` is checked
    as ``"jpg"``.  Names without a dot are rejected.

    Args:
        file_name: Name of the uploaded file.
        format: Collection of allowed extensions (without the dot).
    """
    if '.' not in file_name:
        return False
    # maxsplit=1 so only the final suffix is considered.
    extension = file_name.rsplit('.', 1)[1]
    return extension in format


def check_img_size(img_path):
    """Raise RequestEntityTooLarge if the file at *img_path* exceeds 10M.

    Raises:
        RequestEntityTooLarge: The file is larger than ``ALLOWED_IMG_SIZE``.
    """
    fsize = os.path.getsize(img_path)
    if fsize > ALLOWED_IMG_SIZE:
        raise RequestEntityTooLarge


def create_response(status, textList=None):
    """Build the JSON response returned by the endpoints.

    Args:
        status: Value stored under the ``status`` key.
        textList: Optional results stored under ``textList``; the key is
            omitted when this is None.
    """
    # res is the overall JSON structure.
    res = {'status': status}
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)


@app.route('/')
def hello():
    """Return a fixed greeting."""
    return 'Hello, World!'


@app.route('/text2image', methods=['POST'])
def text2image():
    """Rank the pre-computed image embeddings against the posted caption.

    Reads the caption from the ``string`` form field and responds with the
    image names of the ten highest-scoring entries.
    """
    query = request.form.get("string")
    if query is None:
        # A missing field would otherwise reach encode_cap as None.
        return create_response('string is required')

    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)

    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len,
                       cap_mean, opt, shard_size=1024)
    sims = sims.T
    print(sims.shape)

    # Indices of the first row sorted by descending similarity.
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])

    imgList = []
    for i in inds[:10]:
        # The factor 5 presumably maps image index <-> caption-level rows
        # (5 captions per image) -- TODO confirm against the data files.
        image_name = image_list[5 * int(id_list[5 * i])]
        print(image_name)
        imgList.append(image_name)
    return create_response(1, imgList)


@app.route('/image2text', methods=['POST'])
def image2text():
    """Rank the pre-computed caption embeddings against an uploaded image.

    Reads the upload from the ``file_name`` file field, saves it under
    ``./image/``, and responds with the ten highest-scoring captions.  An
    error message is returned as the status when the extension is not
    allowed or the saved image is larger than 10M.
    """
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img_path = './image/' + secure_filename(f.filename)
        f.save(img_path)
        check_img_size(img_path)

        tic = time.time()
        image_feat = imgFeature(featureModel, cfg, img_path)
        print(image_feat.shape)
        img_emb, img_mean = encode_image(model, image_feat)
        print(img_emb.shape)

        sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens,
                           cap_means, opt, shard_size=10240)
        print(sims.shape)

        # Indices of the first row sorted by descending similarity.
        inds = np.argsort(sims[0])[::-1]
        print(inds[:10])
        print(inds.dtype)

        textList = []
        for i in inds[:10]:
            textList.append(caps_list[i])
            print(caps_list[i])
        print('time: {:.4f}'.format(time.time() - tic))
        return create_response(1, textList)
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5001)