# Graduation_Project/LHL/test_one.py

import os
import time

import numpy as np
import torch
import nltk

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
from extract_features import feature
from model import SCAN
from vocab import deserialize_vocab


def encode_img_caps(model, data_loader, log_step=100, logging=print):
    """Encode all images and captions loadable by `data_loader`."""
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    # numpy arrays that will hold all the embeddings
    img_embs = None
    cap_embs = None
    cap_lens = None

    # first pass: find the longest caption so the caption buffer can be pre-allocated
    max_n_word = 0
    for i, (images, captions, lengths, ids) in enumerate(data_loader):
        max_n_word = max(max_n_word, max(lengths))

    with torch.no_grad():
        for i, (images, captions, lengths, ids) in enumerate(data_loader):
            # make sure val logger is used
            model.logger = val_logger
            lengths = lengths.cpu().numpy().tolist()
            images = images.cuda()
            captions = captions.cuda()

            # compute the embeddings
            img_emb, img_mean, cap_emb, cap_len, cap_mean = model.module.forward_emb(images, captions, lengths)

            # allocate the output buffers once the embedding sizes are known
            if img_embs is None:
                if img_emb.dim() == 3:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
                else:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
                cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
                img_means = np.zeros((len(data_loader.dataset), img_mean.size(1)))
                cap_lens = [0] * len(data_loader.dataset)
                cap_means = np.zeros((len(data_loader.dataset), cap_mean.size(1)))

            # cache embeddings at their dataset indices
            img_embs[ids] = img_emb.data.cpu().numpy().copy()
            img_means[ids] = img_mean.data.cpu().numpy().copy()
            cap_means[ids] = cap_mean.data.cpu().numpy().copy()
            cap_embs[ids, :cap_emb.size(1), :] = cap_emb.data.cpu().numpy().copy()
            for j, nid in enumerate(ids):
                cap_lens[nid] = cap_len[j]

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                logging('Test: [{0}/{1}]\t'
                        '{e_log}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        .format(
                            i, len(data_loader), batch_time=batch_time,
                            e_log=str(model.logger)))
            del images, captions

    return img_embs, img_means, cap_embs, cap_lens, cap_means


def get_cap(cap_str, vocab):
    """Convert a caption string into a (1, n_words) LongTensor of vocabulary ids."""
    # Tokenize the caption and map each token (plus <start>/<end>) to a word id.
    tokens = nltk.tokenize.word_tokenize(
        cap_str.encode('utf-8').decode('utf-8'))
    caption = []
    caption.append(vocab('<start>'))
    caption.extend([vocab(str(token).lower()) for token in tokens])
    caption.append(vocab('<end>'))
    target = torch.Tensor(caption)
    target = torch.unsqueeze(target, 0).long()
    return target
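
# Example (a sketch, assuming `vocab` maps any out-of-vocabulary token to an <unk> id):
#   target = get_cap("a dog runs", vocab)
# `target` is then a LongTensor of shape (1, 5) holding the word ids of
# <start>, 'a', 'dog', 'runs', <end>, ready to be fed to the text encoder.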


def encode_cap(model, cap_str, vocab):
    """Encode a single caption string into word-level and sentence-level embeddings."""
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    cap_e = None
    cap_m = None
    with torch.no_grad():
        captions = get_cap(cap_str, vocab)
        lengths = [len(captions[0])]
        captions = captions.cuda()

        # compute the caption embeddings
        cap_emb, cap_len, cap_mean = model.module.txt_emb(captions, lengths)
        if cap_e is None:
            cap_e = cap_emb.data.cpu().numpy().copy()
            cap_m = cap_mean.data.cpu().numpy().copy()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    return cap_e, cap_len, cap_m


def encode_image(model, image_feat):
    """Encode a single image given as a pre-extracted feature array."""
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()

    img_e = None
    img_m = None

    end = time.time()
    with torch.no_grad():
        # add a batch dimension and move the features to the GPU
        tmp = torch.unsqueeze(torch.from_numpy(image_feat), 0)
        tmp = tmp.cuda()
        img_emb, img_mean = model.module.image_emb(tmp)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if img_e is None:
            img_e = img_emb.data.cpu().numpy().copy()
            img_m = img_mean.data.cpu().numpy().copy()

    return img_e, img_m
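
# Note: encode_image expects the pre-extracted features of one image as a numpy
# array (the kind of array that feature() from extract_features is assumed to
# return; see the __main__ block below). It returns the per-region embedding and
# the aggregated (mean) image embedding.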


def encode_data(model, data_loader, log_step=10, logging=print):
    """Encode the captions from `data_loader`, pairing them with a single image
    whose pre-extracted features are loaded from 'data.npy' (single-image variant,
    stops after the first batch)."""
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    img_embs = None
    cap_embs = None
    cap_lens = None
    img_e = None
    img_m = None
    img_id = None

    # first pass: find the longest caption so the caption buffer can be pre-allocated
    max_n_word = 0
    for i, (images, captions, lengths, ids) in enumerate(data_loader):
        max_n_word = max(max_n_word, max(lengths))

    # pre-extracted features of the single query image
    tmp = np.load('data.npy')
    tmp = torch.unsqueeze(torch.from_numpy(tmp), 0)
    tmp = tmp.cuda()

    with torch.no_grad():
        for i, (images, captions, lengths, ids) in enumerate(data_loader):
            # make sure val logger is used
            model.logger = val_logger
            lengths = lengths.cpu().numpy().tolist()
            images = images.cuda()
            captions = captions.cuda()

            # compute the embeddings (the loaded query image replaces the batch images)
            img_emb, img_mean, cap_emb, cap_len, cap_mean = model.module.forward_emb(tmp, captions, lengths)

            # allocate the output buffers once the embedding sizes are known
            if img_embs is None:
                if img_emb.dim() == 3:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
                else:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
                cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
                img_means = np.zeros((len(data_loader.dataset), img_mean.size(1)))
                cap_lens = [0] * len(data_loader.dataset)
                cap_means = np.zeros((len(data_loader.dataset), cap_mean.size(1)))

            # cache embeddings at their dataset indices
            img_embs[ids] = img_emb.data.cpu().numpy().copy()
            img_means[ids] = img_mean.data.cpu().numpy().copy()
            cap_means[ids] = cap_mean.data.cpu().numpy().copy()
            cap_embs[ids, :cap_emb.size(1), :] = cap_emb.data.cpu().numpy().copy()
            for j, nid in enumerate(ids):
                cap_lens[nid] = cap_len[j]

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                logging('Test: [{0}/{1}]\t'
                        '{e_log}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        .format(
                            i, len(data_loader), batch_time=batch_time,
                            e_log=str(model.logger)))
            del images, captions

            # keep only the embedding of the single query image and stop after
            # the first batch
            if img_e is None:
                img_e = img_emb.data.cpu().numpy().copy()
                img_m = img_mean.data.cpu().numpy().copy()
                img_id = ids[0]
                break

    return img_e, img_m, cap_embs, cap_lens, cap_means, img_id


if __name__ == '__main__':
    model_path = "./runs/test/model_best.pth.tar"
    data_path = "./data/"
    image_path = "./image/ride.jpg"

    # load the trained checkpoint and the options it was trained with
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)

    # caption strings for the test split
    caps_list = []
    with open("test_caps.txt", "r") as f:
        for line in f.readlines():
            caps_list.append(line.strip("\n"))
    print(len(caps_list))

    # image file names (the part before '#' on each line of result.txt)
    image_list = []
    with open("result.txt", "r") as f:
        for line in f.readlines():
            image_list.append(line.strip("\n").split("#")[0])

    # image ids for the test split
    id_list = []
    with open("test_ids.txt", "r") as f:
        for line in f.readlines():
            id_list.append(line.strip("\n"))

    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    word2idx = vocab.word2idx
    opt.vocab_size = len(vocab)

    # construct the model and load the trained weights
    model = SCAN(word2idx, opt)
    model = torch.nn.DataParallel(model)
    model.cuda()
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader("test", opt.data_name, vocab,
                                  opt.batch_size, 0, opt)

    print('Computing results...')
    # encode every test image and caption
    img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
    print(img_embs.shape, cap_embs.shape)

    test_str = "A little boy is playing football on the football field"

    # extract region features for the query image and encode it
    image_feat = feature(image_path)
    print(image_feat.shape)
    img_emb, img_mean = encode_image(model, image_feat)
    print(img_emb.shape)

    # rank all test captions against the query image (image-to-text retrieval)
    sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=2048)
    print(sims.shape)
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    print(inds.dtype)
    for i in inds[:10]:
        print(caps_list[i])
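
    # Sketch of the reverse, text-to-image query (kept commented out): encode one
    # caption with encode_cap and rank the pre-computed image embeddings against it.
    # Assumes each image has 5 captions, so the image-level arrays are subsampled
    # by 5 and image_list/id_list are indexed accordingly.
    # cap_emb, cap_len, cap_mean = encode_cap(model, test_str, vocab)
    # img_embs_u = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
    # img_means_u = np.array([img_means[i] for i in range(0, len(img_means), 5)])
    # sims_t2i = shard_xattn(model, img_embs_u, img_means_u, cap_emb, cap_len, cap_mean, opt, shard_size=1024).T
    # for i in np.argsort(sims_t2i[0])[::-1][:10]:
    #     print(image_list[5 * int(id_list[i])])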