"""Build, save and load a word-level vocabulary from caption text files.

Run as a script to tokenize the caption files of one dataset (see
``annotations``) and write the resulting vocabulary as JSON to
``vocab/<data_name>_vocab.json``.
"""
import argparse
import json
import os
from collections import Counter

try:
    import nltk
except ImportError:
    # Only build_vocab() needs nltk; Vocabulary and the (de)serializers can
    # still be imported and used without it.
    nltk = None

# Caption files, relative to <data_path>/<data_name>/, scanned per dataset.
annotations = {
    'coco_splits': ['train_caps.txt', 'val_caps.txt', 'test_caps.txt'],
    'flickr30k_splits': ['train_caps.txt', 'val_caps.txt', 'test_caps.txt'],
    'rsicd_precomp': ['train_caps.txt', 'test_caps.txt'],
    'rsitmd_precomp': ['train_caps.txt', 'test_caps.txt'],
    'ucm_precomp': ['train_caps.txt', 'val_caps.txt'],
    'sydney_precomp': ['train_caps.txt', 'val_caps.txt'],
}

# Tokens dropped from every caption before counting. Built once as a set so
# the per-token membership test is O(1).
_PUNCTUATION = frozenset([
    ',', '.', ':', ';', '?', '(', ')', '[', ']',
    '&', '!', '*', '@', '#', '$', '%',
])


class Vocabulary(object):
    """Simple vocabulary wrapper mapping words to integer ids and back.

    Ids are assigned in insertion order, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next id to hand out

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are ignored."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        """Return the id of ``word``, or the id of the ``''`` entry if unknown.

        Raises:
            KeyError: if ``word`` is unknown and ``''`` was never added.
        """
        # NOTE(review): the fallback key is the empty string; this may be a
        # markup-stripped special token -- confirm the intended literal.
        if word not in self.word2idx:
            return self.word2idx['']
        return self.word2idx[word]

    def __len__(self):
        """Return the number of distinct words in the vocabulary."""
        return len(self.word2idx)


def serialize_vocab(vocab, dest):
    """Write ``vocab`` to ``dest`` as a JSON object.

    The object has the keys ``word2idx``, ``idx2word`` and ``idx``. The parent
    directory of ``dest`` is created when it does not exist yet, so a fresh
    checkout without the output folder no longer fails on ``open``.
    """
    out_dir = os.path.dirname(dest)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    payload = {
        'word2idx': vocab.word2idx,
        'idx2word': vocab.idx2word,
        'idx': vocab.idx,
    }
    with open(dest, "w") as f:
        json.dump(payload, f)


def deserialize_vocab(src):
    """Load a ``Vocabulary`` previously written by ``serialize_vocab``.

    JSON object keys are always strings, so ``idx2word`` of the returned
    vocabulary is keyed by ``'0'``, ``'1'``, ... rather than by ints. This is
    left as-is because existing callers may rely on the string keys.
    """
    with open(src) as f:
        d = json.load(f)
    vocab = Vocabulary()
    vocab.word2idx = d['word2idx']
    vocab.idx2word = d['idx2word']
    vocab.idx = d['idx']
    return vocab


def from_txt(txt):
    """Return the lines of file ``txt`` as a list of whitespace-stripped bytes."""
    # Opened in binary mode: callers receive bytes and decode them themselves.
    with open(txt, 'rb') as f:
        return [line.strip() for line in f]


def build_vocab(data_path, data_name, caption_file, threshold):
    """Build a vocabulary from the caption files of one dataset.

    Args:
        data_path: root directory that contains the dataset folders.
        data_name: dataset folder name; also the key into ``caption_file``.
        caption_file: mapping from dataset name to its caption file names.
        threshold: minimum number of occurrences for a word to be kept.

    Returns:
        A ``Vocabulary`` holding the ``''`` entry followed by every kept word
        in order of first appearance.

    Raises:
        ImportError: if nltk is not installed.
        KeyError: if ``data_name`` is not a key of ``caption_file``.
    """
    if nltk is None:
        raise ImportError("nltk is required to build a vocabulary")

    # One set for everything we drop: punctuation plus English stop words.
    ignored = _PUNCTUATION | set(nltk.corpus.stopwords.words('english'))

    counter = Counter()
    for path in caption_file[data_name]:
        full_path = os.path.join(data_path, data_name, path)
        captions = from_txt(full_path)

        for i, caption in enumerate(captions):
            # Captions are bytes (see from_txt): lowercase, then decode.
            tokens = nltk.tokenize.word_tokenize(
                caption.lower().decode('utf-8'))
            counter.update(k for k in tokens if k not in ignored)

            if i % 1000 == 0:
                print("[%d/%d] tokenized the captions." % (i, len(captions)))

    # Discard words that occur fewer than `threshold` times.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add the special tokens.
    # NOTE(review): both special-token literals are the empty string, so the
    # second call is currently a no-op and the vocabulary gets one special
    # entry at id 0. The literals may be markup-stripped tokens -- confirm
    # before changing them, since that would shift every word id.
    vocab = Vocabulary()
    vocab.add_word('')
    vocab.add_word('')

    # Add words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab


def main(data_path, data_name):
    """Build the vocabulary for ``data_name`` and save it under ``vocab/``."""
    vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4)
    dest = 'vocab/%s_vocab.json' % data_name
    serialize_vocab(vocab, dest)
    print("Saved vocabulary file to ", dest)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='data')
    parser.add_argument('--data_name', default='rsitmd_precomp',
                        help='{coco,f30k}')
    opt = parser.parse_args()
    main(opt.data_path, opt.data_name)