import spacy
from pprint import pprint
from transformers import BertModel, BertTokenizer
import torch


"""
Dependency patterns of interest:

Noun -> Verb: tag=NN, head=verb
Adj  -> Noun: tag=JJ, head=Noun
Num  -> Noun: tag=CD, head=Noun
of   -> Noun: tag=NN, head=of
Quan -> of:   word=of, tag=IN, head=Quan
"""

class Parser(object):

    def __init__(self) -> None:
        # For POS tagging and dependency parsing
        self.pos_model = spacy.load("en_core_web_sm")

        # For word embedding
        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Load pre-trained model (weights)
        self.model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        self.model.eval()

    def parse(self, text: str):
        # For POS tagging and dependency parsing
        outputs = self.pos_model(text)

        sentence = list()
        edges = []

        tag_ls = list()
        dep_dict = dict()

        for item in outputs:
            idx = item.i
            word = item.text
            tag = item.tag_
            dep = item.dep_
            head = item.head

            sentence.append(word)
            print(f"word={word}, tag={tag}, head={head}, dep={dep}")

            # POS tagging
            if "VB" in tag:
                tag_ls.append("VB")
                continue
            elif "NN" in tag:
                tag_ls.append("NN")
                # Add noun to dependency dict
                dep_dict[idx] = list()
                continue
            elif "JJ" in tag:
                tag_ls.append("JJ")
                continue
            elif "CD" in tag:
                tag_ls.append("CD")
                continue
            else:
                tag_ls.append(tag)

        # Dependency parsing
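        # The block below is an assumed sketch, not part of the original logic:
        # it fills `edges` (otherwise returned empty) with (dependent index,
        # head index, dependency label) triples taken directly from spaCy's
        # parse. The tag-based patterns listed in the module docstring are not
        # filtered for here; every arc except the root's self-loop is kept.
        for item in outputs:
            if item.i != item.head.i:  # the root token heads itself
                edges.append((item.i, item.head.i, item.dep_))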

        # For word embedding
        # Tokenize input
        tokenized_text = self.tokenizer.tokenize(text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor([indexed_tokens])

        # Predict hidden states features for each layer
        with torch.no_grad():
            outputs = self.model(tokens_tensor)
            # With output_hidden_states=True, element 2 of the output is the
            # tuple of hidden states (embedding layer + one tensor per layer)
            hidden_states = outputs[2]

        # `token_vecs` is a tensor with shape [txt_len x 768]:
        # token vectors from the second-to-last layer for the single sentence
        token_vecs = hidden_states[-2][0]

        return sentence, tag_ls, edges, token_vecs


if __name__ == "__main__":
    parser = Parser()
    text = "Rolling the rice into a ball about the size of a large tomato, add 2 cups of sweet sugar, and five bottles of water."
    sentence, tag_ls, edges, token_vecs = parser.parse(text)

    print(token_vecs.shape)

    # for i in range(len(sentence)):
    #     print(sentence[i], tag_ls[i])