# Graduation_Project/QN/RecipeRetrieval/utils/tools.py

import spacy
from pprint import pprint
from transformers import BertModel, BertTokenizer
import torch
"""
Noun -> Verb: tag=NN, head=verb
Adj -> Noun: tag=JJ, head=Noun
Num -> Noun: tag=CD, head=Noun
of -> Noun: tag=NN, head=of
Quan -> of: word=of, tag=IN, head=Quan
"""
class Parser(object):
    """Sentence parser combining spaCy and BERT.

    spaCy supplies POS tags and the dependency tree; a pre-trained BERT
    model supplies contextual token embeddings.  The simplified-tag scheme
    (see the module-level notes) keeps only the VB/NN/JJ/CD groups and
    passes every other Penn Treebank tag through unchanged.
    """

    # Tag groups collapsed by parse(), checked in this order.
    _TAG_GROUPS = ("VB", "NN", "JJ", "CD")

    def __init__(self) -> None:
        # spaCy English pipeline for POS tagging and dependency parsing.
        self.pos_model = spacy.load("en_core_web_sm")
        # Pre-trained BERT tokenizer (vocabulary).
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Pre-trained BERT weights; output_hidden_states=True makes the model
        # return every layer's activations, which parse() needs in order to
        # pick the second-to-last layer.
        self.model = BertModel.from_pretrained(
            'bert-base-uncased', output_hidden_states=True)
        # Evaluation mode: disables dropout so embeddings are deterministic.
        self.model.eval()

    def parse(self, text: str, verbose: bool = False):
        """Parse ``text`` into tokens, simplified tags, edges and embeddings.

        Args:
            text: The input sentence.
            verbose: When True, print per-token debug information
                (word, tag, head, dependency label).  The original
                implementation printed unconditionally; debug output is
                now opt-in.

        Returns:
            A 4-tuple ``(sentence, tag_ls, edges, token_vecs)``:
              * ``sentence``: list of spaCy token strings.
              * ``tag_ls``: one simplified POS tag per token — "VB"/"NN"/
                "JJ"/"CD" for those groups, otherwise the raw Penn tag.
              * ``edges``: dependency edges — currently always empty; the
                dependency-extraction step described in the module notes
                is not implemented yet (TODO).
              * ``token_vecs``: tensor of shape [bert_token_len, 768] from
                BERT's second-to-last hidden layer.  Note bert_token_len
                generally differs from ``len(sentence)`` because BERT uses
                WordPiece tokenization.
        """
        # --- POS tagging and dependency parsing (spaCy) ---
        doc = self.pos_model(text)
        sentence = []
        edges = []     # TODO: populate from the dependency tree (see module notes)
        tag_ls = []
        dep_dict = {}  # noun index -> dependents; scaffolding for the TODO above
        for token in doc:
            sentence.append(token.text)
            if verbose:
                print(f"word={token.text}, tag={token.tag_}, "
                      f"head={token.head}, dep={token.dep_}")
            # Collapse related Penn tags into one group label.  Substring
            # match mirrors the original behavior ("VB" matches VB, VBD,
            # VBG, ...); checked in declaration order.
            for group in self._TAG_GROUPS:
                if group in token.tag_:
                    tag_ls.append(group)
                    if group == "NN":
                        # Nouns will anchor dependency edges once extraction
                        # is implemented.
                        dep_dict[token.i] = []
                    break
            else:
                tag_ls.append(token.tag_)

        # --- Word embedding (BERT) ---
        # NOTE(review): no [CLS]/[SEP] special tokens are added, so the
        # embeddings are computed on the raw WordPiece sequence — confirm
        # this is intended before aligning with the spaCy tokens.
        tokenized_text = self.tokenizer.tokenize(text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        # No gradients needed for inference-only embedding extraction.
        with torch.no_grad():
            outputs = self.model(tokens_tensor)
            # Index 2 is the per-layer hidden-states tuple (embedding layer
            # + one entry per transformer layer) because the model was
            # loaded with output_hidden_states=True.
            hidden_states = outputs[2]
        # Second-to-last layer, first (only) batch element:
        # shape [bert_token_len, 768].
        token_vecs = hidden_states[-2][0]
        return sentence, tag_ls, edges, token_vecs
if __name__ == "__main__":
    # Quick smoke test: parse one sample sentence and report the shape of
    # the BERT token embeddings (expected: [bert_token_len, 768]).
    parser = Parser()
    text = "Rolling the rice into a ball about the size of a large tomato, add 2 cups of sweet sugar, and five bottles of water."
    sentence, tag_ls, edges, token_vecs = parser.parse(text)
    print(token_vecs.shape)