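"""Retrieval demo for a trained SCAN model.

Loads a checkpoint, encodes the test split (images and captions), then
embeds a single query image (or caption string) and ranks the cached test
captions (or images) against it with shard_xattn.
"""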
import os
import time

import numpy as np
import torch
import nltk

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
from extract_features import feature
from model import SCAN
from vocab import deserialize_vocab


def encode_img_caps(model, data_loader, log_step=100, logging=print):
    """Encode all images and captions loadable by `data_loader`."""
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # numpy arrays to keep all the embeddings
    img_embs = None
    cap_embs = None
    cap_lens = None

    # first pass: find the longest caption so cap_embs can be padded to it
    max_n_word = 0
    for i, (images, captions, lengths, ids) in enumerate(data_loader):
        max_n_word = max(max_n_word, max(lengths))

    with torch.no_grad():
        for i, (images, captions, lengths, ids) in enumerate(data_loader):
            # make sure val logger is used
            model.logger = val_logger
            lengths = lengths.cpu().numpy().tolist()
            images = images.cuda()
            captions = captions.cuda()

            # compute the embeddings
            img_emb, img_mean, cap_emb, cap_len, cap_mean = model.module.forward_emb(
                images, captions, lengths)

            if img_embs is None:
                # allocate storage on the first batch, once shapes are known
                if img_emb.dim() == 3:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
                else:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
                cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
                img_means = np.zeros((len(data_loader.dataset), img_mean.size(1)))
                cap_lens = [0] * len(data_loader.dataset)
                cap_means = np.zeros((len(data_loader.dataset), cap_mean.size(1)))

            # cache embeddings at their dataset indices
            img_embs[ids] = img_emb.data.cpu().numpy().copy()
            img_means[ids] = img_mean.data.cpu().numpy().copy()
            cap_means[ids] = cap_mean.data.cpu().numpy().copy()
            cap_embs[ids, :cap_emb.size(1), :] = cap_emb.data.cpu().numpy().copy()
            for j, nid in enumerate(ids):
                cap_lens[nid] = cap_len[j]

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                logging('Test: [{0}/{1}]\t'
                        '{e_log}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        .format(
                            i, len(data_loader), batch_time=batch_time,
                            e_log=str(model.logger)))
            del images, captions
    return img_embs, img_means, cap_embs, cap_lens, cap_means

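# Typical use (mirrors the __main__ block below): cache the test split once,
# then score queries against the cached arrays with shard_xattn, e.g.
#   img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, loader)
#   sims = shard_xattn(model, img_embs, img_means, cap_embs, cap_lens, cap_means, opt, shard_size=512)
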
def get_cap(cap_str, vocab):
    """Convert a caption string to a 1 x n tensor of word ids."""
    # tokenize, then map each token to its vocabulary index
    tokens = nltk.tokenize.word_tokenize(cap_str)
    caption = []
    caption.append(vocab('<start>'))
    caption.extend([vocab(str(token).lower()) for token in tokens])
    caption.append(vocab('<end>'))
    target = torch.Tensor(caption)
    target = torch.unsqueeze(target, 0).long()
    return target

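# Example: get_cap("a dog runs", vocab) returns a (1, 5) LongTensor holding
# the ids of <start>, "a", "dog", "runs", <end>.
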
def encode_cap(model, cap_str, vocab):
    """Encode a single caption string with the text branch of the model."""
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # numpy arrays to keep the embeddings
    cap_e = None
    cap_m = None

    with torch.no_grad():
        captions = get_cap(cap_str, vocab)
        lengths = [len(captions[0])]
        captions = captions.cuda()

        # compute the embeddings
        cap_emb, cap_len, cap_mean = model.module.txt_emb(captions, lengths)
        if cap_e is None:
            cap_e = cap_emb.data.cpu().numpy().copy()
            cap_m = cap_mean.data.cpu().numpy().copy()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    return cap_e, cap_len, cap_m

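# encode_cap runs the text branch on a single sentence (batch size 1): it
# returns the per-word embeddings, the caption length, and the mean-pooled
# caption vector, matching the caption-side outputs of encode_img_caps.
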
def encode_image(model, image_feat):
    """Encode a single precomputed image feature with the image branch."""
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # numpy arrays to keep the embeddings
    img_e = None
    img_m = None

    end = time.time()

    with torch.no_grad():
        # add a batch dimension and move the feature to the GPU
        tmp = torch.unsqueeze(torch.from_numpy(image_feat), 0)
        tmp = tmp.cuda()
        img_emb, img_mean = model.module.image_emb(tmp)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if img_e is None:
            img_e = img_emb.data.cpu().numpy().copy()
            img_m = img_mean.data.cpu().numpy().copy()

    return img_e, img_m

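# Illustrative helper (a sketch, not part of the original API): wraps the
# image-to-text query that __main__ performs step by step below, assuming
# shard_xattn yields one row of similarities for the single query image.
def rank_captions(model, image_feat, cap_embs, cap_lens, cap_means, opt, topk=10):
    img_emb, img_mean = encode_image(model, image_feat)
    sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=2048)
    # best-matching caption indices first
    return np.argsort(sims[0])[::-1][:topk]
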
def encode_data(model, data_loader, log_step=10, logging=print):
    """Debug variant of encode_img_caps: encode only the first batch, with
    the loader's image features replaced by a fixed feature from 'data.npy'."""
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # numpy arrays to keep all the embeddings
    img_embs = None
    cap_embs = None
    cap_lens = None
    img_e = None
    img_m = None
    img_id = None

    # first pass: find the longest caption so cap_embs can be padded to it
    max_n_word = 0
    for i, (images, captions, lengths, ids) in enumerate(data_loader):
        max_n_word = max(max_n_word, max(lengths))

    # fixed image feature that replaces the loader's images below
    tmp = np.load('data.npy')
    tmp = torch.unsqueeze(torch.from_numpy(tmp), 0)
    tmp = tmp.cuda()

    with torch.no_grad():
        for i, (images, captions, lengths, ids) in enumerate(data_loader):
            # make sure val logger is used
            model.logger = val_logger
            lengths = lengths.cpu().numpy().tolist()
            images = images.cuda()
            captions = captions.cuda()

            # compute the embeddings (note: `tmp`, not `images`, is encoded)
            img_emb, img_mean, cap_emb, cap_len, cap_mean = model.module.forward_emb(
                tmp, captions, lengths)

            if img_embs is None:
                # allocate storage on the first batch, once shapes are known
                if img_emb.dim() == 3:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
                else:
                    img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
                cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
                img_means = np.zeros((len(data_loader.dataset), img_mean.size(1)))
                cap_lens = [0] * len(data_loader.dataset)
                cap_means = np.zeros((len(data_loader.dataset), cap_mean.size(1)))

            # cache embeddings at their dataset indices
            img_embs[ids] = img_emb.data.cpu().numpy().copy()
            img_means[ids] = img_mean.data.cpu().numpy().copy()
            cap_means[ids] = cap_mean.data.cpu().numpy().copy()
            cap_embs[ids, :cap_emb.size(1), :] = cap_emb.data.cpu().numpy().copy()
            for j, nid in enumerate(ids):
                cap_lens[nid] = cap_len[j]

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                logging('Test: [{0}/{1}]\t'
                        '{e_log}\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        .format(
                            i, len(data_loader), batch_time=batch_time,
                            e_log=str(model.logger)))
            del images, captions

            if img_e is None:
                img_e = img_emb.data.cpu().numpy().copy()
                img_m = img_mean.data.cpu().numpy().copy()
                img_id = ids[0]
            # stop after the first batch: this function only debugs one image
            break
    return img_e, img_m, cap_embs, cap_lens, cap_means, img_id

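# Demo entry point: load a checkpoint, rebuild the vocabulary and model,
# encode the test split, then embed one query image and print its top-10
# retrieved captions.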
if __name__ == '__main__':
    model_path = "./runs/test/model_best.pth.tar"
    data_path = "./data/"
    image_path = "./image/ride.jpg"

    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)

    # test captions, one per line
    caps_list = []
    with open("test_caps.txt", "r") as f:
        for line in f.readlines():
            caps_list.append(line.strip("\n"))
    print(len(caps_list))

    # image filenames: keep the part of each line before the first "#"
    image_list = []
    with open("result.txt", "r") as f:
        for line in f.readlines():
            line = line.strip("\n")
            image_list.append(line.split("#")[0])

    # dataset ids for the test split, one per line
    id_list = []
    with open("test_ids.txt", "r") as f:
        for line in f.readlines():
            id_list.append(line.strip("\n"))

    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    word2idx = vocab.word2idx
    opt.vocab_size = len(vocab)

    # construct the model and load the trained weights
    model = SCAN(word2idx, opt)
    model = torch.nn.DataParallel(model)
    model.cuda()
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader("test", opt.data_name, vocab,
                                  opt.batch_size, 0, opt)

    print('Computing results...')
    # debug alternative: encode only the first batch via encode_data
    # img_embs, img_means, cap_embs, cap_lens, cap_means, img_id = encode_data(model, data_loader)
    img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
    print(img_embs.shape, cap_embs.shape)

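    # Note: the test split stores one entry per caption (five per image),
    # which is why the commented text-query path below subsamples image
    # embeddings with a stride of 5.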
    # text query (alternative direction, kept disabled): encode one caption
    # and rank the test images against it
    test_str = "A little boy is playing football on the football field"
    # cap_emb, cap_len, cap_mean = encode_cap(model, test_str, vocab)
    # img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
    # sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    # sims = sims.T
    # inds = np.argsort(sims[0])[::-1]
    # for i in inds[:10]:
    #     print(image_list[5 * int(id_list[i])])

    # image query: extract features for the query image and embed them
    image_feat = feature(image_path)
    print(image_feat.shape)
    img_emb, img_mean = encode_image(model, image_feat)
    print(img_emb.shape)

    # full retrieval evaluation over the whole test split (kept disabled)
    # sims = shard_xattn(model, img_embs, img_means, cap_embs, cap_lens, cap_means, opt, shard_size=512)
    # r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    # print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
    # ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    # print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

    # rank all cached test captions against the single query image
    sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=2048)
    print(sims.shape)
    inds = np.argsort(sims[0])[::-1]  # highest similarity first
    print(inds[:10])
    for i in inds[:10]:
        print(caps_list[i])