#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 3
parser_utils.py: Utilities for training the dependency parser.
Sahil Chopra
"""

import time
import os
import logging
from collections import Counter

from .general_utils import get_minibatches
from parser_transitions import minibatch_parse

from tqdm import tqdm
import torch
import numpy as np
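
# POS tags and dependency labels share the tok2id vocabulary with words, so
# they are stored with the P_PREFIX / L_PREFIX markers below to avoid id
# collisions; UNK, NULL, and ROOT are the unknown, padding, and artificial
# root tokens.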

P_PREFIX = '<p>:'
L_PREFIX = '<l>:'
UNK = '<UNK>'
NULL = '<NULL>'
ROOT = '<ROOT>'


class Config(object):
    language = 'english'
    with_punct = True
    unlabeled = True
    lowercase = True
    use_pos = True
    use_dep = True
    use_dep = use_dep and (not unlabeled)
    data_path = './data'
    train_file = 'train.conll'
    dev_file = 'dev.conll'
    test_file = 'test.conll'
    embedding_file = './data/en-cw.txt'


class Parser(object):
    """Contains everything needed for transition-based dependency parsing except for the model"""

    def __init__(self, dataset):
        root_labels = list([l for ex in dataset
                            for (h, l) in zip(ex['head'], ex['label']) if h == 0])
        counter = Counter(root_labels)
        if len(counter) > 1:
            logging.info('Warning: more than one root label')
            logging.info(counter)
        self.root_label = counter.most_common()[0][0]
        deprel = [self.root_label] + list(set([w for ex in dataset
                                               for w in ex['label']
                                               if w != self.root_label]))
        tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}
        tok2id[L_PREFIX + NULL] = self.L_NULL = len(tok2id)

        config = Config()
        self.unlabeled = config.unlabeled
        self.with_punct = config.with_punct
        self.use_pos = config.use_pos
        self.use_dep = config.use_dep
        self.language = config.language

        if self.unlabeled:
            trans = ['L', 'R', 'S']
            self.n_deprel = 1
        else:
            trans = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['S']
            self.n_deprel = len(deprel)

        self.n_trans = len(trans)
        self.tran2id = {t: i for (i, t) in enumerate(trans)}
        self.id2tran = {i: t for (i, t) in enumerate(trans)}

        # logging.info('Build dictionary for part-of-speech tags.')
        tok2id.update(build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
                                 offset=len(tok2id)))
        tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
        tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
        tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)

        # logging.info('Build dictionary for words.')
        tok2id.update(build_dict([w for ex in dataset for w in ex['word']],
                                 offset=len(tok2id)))
        tok2id[UNK] = self.UNK = len(tok2id)
        tok2id[NULL] = self.NULL = len(tok2id)
        tok2id[ROOT] = self.ROOT = len(tok2id)

        self.tok2id = tok2id
        self.id2tok = {v: k for (k, v) in tok2id.items()}

        self.n_features = 18 + (18 if config.use_pos else 0) + (12 if config.use_dep else 0)
        self.n_tokens = len(tok2id)

    def vectorize(self, examples):
        vec_examples = []
        for ex in examples:
            word = [self.ROOT] + [self.tok2id[w] if w in self.tok2id
                                  else self.UNK for w in ex['word']]
            pos = [self.P_ROOT] + [self.tok2id[P_PREFIX + w] if P_PREFIX + w in self.tok2id
                                   else self.P_UNK for w in ex['pos']]
            head = [-1] + ex['head']
            label = [-1] + [self.tok2id[L_PREFIX + w] if L_PREFIX + w in self.tok2id
                            else -1 for w in ex['label']]
            vec_examples.append({'word': word, 'pos': pos,
                                 'head': head, 'label': label})
        return vec_examples

    def extract_features(self, stack, buf, arcs, ex):
        if stack[0] == "ROOT":
            stack[0] = 0

        def get_lc(k):
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] < k])

        def get_rc(k):
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] > k],
                          reverse=True)

        p_features = []
        l_features = []
        features = [self.NULL] * (3 - len(stack)) + [ex['word'][x] for x in stack[-3:]]
        features += [ex['word'][x] for x in buf[:3]] + [self.NULL] * (3 - len(buf))
        if self.use_pos:
            p_features = [self.P_NULL] * (3 - len(stack)) + [ex['pos'][x] for x in stack[-3:]]
            p_features += [ex['pos'][x] for x in buf[:3]] + [self.P_NULL] * (3 - len(buf))

        for i in range(2):
            if i < len(stack):
                k = stack[-i-1]
                lc = get_lc(k)
                rc = get_rc(k)
                llc = get_lc(lc[0]) if len(lc) > 0 else []
                rrc = get_rc(rc[0]) if len(rc) > 0 else []
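                # For each of the two top stack words, also record its two
                # leftmost and rightmost children and the leftmost-of-leftmost
                # / rightmost-of-rightmost grandchildren, backing off to the
                # NULL token whenever a child does not exist.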
                features.append(ex['word'][lc[0]] if len(lc) > 0 else self.NULL)
                features.append(ex['word'][rc[0]] if len(rc) > 0 else self.NULL)
                features.append(ex['word'][lc[1]] if len(lc) > 1 else self.NULL)
                features.append(ex['word'][rc[1]] if len(rc) > 1 else self.NULL)
                features.append(ex['word'][llc[0]] if len(llc) > 0 else self.NULL)
                features.append(ex['word'][rrc[0]] if len(rrc) > 0 else self.NULL)
                if self.use_pos:
                    p_features.append(ex['pos'][lc[0]] if len(lc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rc[0]] if len(rc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][lc[1]] if len(lc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][rc[1]] if len(rc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][llc[0]] if len(llc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rrc[0]] if len(rrc) > 0 else self.P_NULL)
                if self.use_dep:
                    l_features.append(ex['label'][lc[0]] if len(lc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rc[0]] if len(rc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][lc[1]] if len(lc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][rc[1]] if len(rc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][llc[0]] if len(llc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rrc[0]] if len(rrc) > 0 else self.L_NULL)
            else:
                features += [self.NULL] * 6
                if self.use_pos:
                    p_features += [self.P_NULL] * 6
                if self.use_dep:
                    l_features += [self.L_NULL] * 6

        features += p_features + l_features
        assert len(features) == self.n_features
        return features

    def get_oracle(self, stack, buf, ex):
        if len(stack) < 2:
            return self.n_trans - 1

        i0 = stack[-1]
        i1 = stack[-2]
        h0 = ex['head'][i0]
        h1 = ex['head'][i1]
        l0 = ex['label'][i0]
        l1 = ex['label'][i1]

        if self.unlabeled:
            if (i1 > 0) and (h1 == i0):
                return 0
            elif (i1 >= 0) and (h0 == i1) and \
                 (not any([x for x in buf if ex['head'][x] == i0])):
                return 1
            else:
                return None if len(buf) == 0 else 2
        else:
            if (i1 > 0) and (h1 == i0):
                return l1 if (l1 >= 0) and (l1 < self.n_deprel) else None
            elif (i1 >= 0) and (h0 == i1) and \
                 (not any([x for x in buf if ex['head'][x] == i0])):
                return l0 + self.n_deprel if (l0 >= 0) and (l0 < self.n_deprel) else None
            else:
                return None if len(buf) == 0 else self.n_trans - 1

    def create_instances(self, examples):
        all_instances = []
        succ = 0
        for id, ex in enumerate(examples):
            n_words = len(ex['word']) - 1

            # arcs = {(h, t, label)}
            stack = [0]
            buf = [i + 1 for i in range(n_words)]
            arcs = []
            instances = []
            for i in range(n_words * 2):
                gold_t = self.get_oracle(stack, buf, ex)
                if gold_t is None:
                    break
                legal_labels = self.legal_labels(stack, buf)
                assert legal_labels[gold_t] == 1
                instances.append((self.extract_features(stack, buf, arcs, ex),
                                  legal_labels, gold_t))
                if gold_t == self.n_trans - 1:
                    stack.append(buf[0])
                    buf = buf[1:]
                elif gold_t < self.n_deprel:
                    arcs.append((stack[-1], stack[-2], gold_t))
                    stack = stack[:-2] + [stack[-1]]
                else:
                    arcs.append((stack[-2], stack[-1], gold_t - self.n_deprel))
                    stack = stack[:-1]
            else:
                succ += 1
                all_instances += instances

        return all_instances

    def legal_labels(self, stack, buf):
        labels = ([1] if len(stack) > 2 else [0]) * self.n_deprel
        labels += ([1] if len(stack) >= 2 else [0]) * self.n_deprel
        labels += [1] if len(buf) > 0 else [0]
        return labels

    def parse(self, dataset, eval_batch_size=5000):
        sentences = []
        sentence_id_to_idx = {}
        for i, example in enumerate(dataset):
            n_words = len(example['word']) - 1
            sentence = [j + 1 for j in range(n_words)]
            sentences.append(sentence)
            sentence_id_to_idx[id(sentence)] = i

        model = ModelWrapper(self, dataset, sentence_id_to_idx)
        dependencies = minibatch_parse(sentences, model, eval_batch_size)
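
        # Unlabeled attachment score: the fraction of tokens whose predicted
        # head matches the gold head; punctuation is skipped unless
        # config.with_punct is set.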
        UAS = all_tokens = 0.0
        with tqdm(total=len(dataset)) as prog:
            for i, ex in enumerate(dataset):
                head = [-1] * len(ex['word'])
                for h, t in dependencies[i]:
                    head[t] = h
                for pred_h, gold_h, gold_l, pos in \
                        zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                    assert self.id2tok[pos].startswith(P_PREFIX)
                    pos_str = self.id2tok[pos][len(P_PREFIX):]
                    if (self.with_punct) or (not punct(self.language, pos_str)):
                        UAS += 1 if pred_h == gold_h else 0
                        all_tokens += 1
                prog.update(1)  # advance the progress bar by one sentence
        UAS /= all_tokens
        return UAS, dependencies


class ModelWrapper(object):
    def __init__(self, parser, dataset, sentence_id_to_idx):
        self.parser = parser
        self.dataset = dataset
        self.sentence_id_to_idx = sentence_id_to_idx

    def predict(self, partial_parses):
        mb_x = [self.parser.extract_features(p.stack, p.buffer, p.dependencies,
                                             self.dataset[self.sentence_id_to_idx[id(p.sentence)]])
                for p in partial_parses]
        mb_x = np.array(mb_x).astype('int32')
        mb_x = torch.from_numpy(mb_x).long()
        mb_l = [self.parser.legal_labels(p.stack, p.buffer) for p in partial_parses]

        pred = self.parser.model(mb_x)
        pred = pred.detach().numpy()
        pred = np.argmax(pred + 10000 * np.array(mb_l).astype('float32'), 1)
        pred = ["S" if p == 2 else ("LA" if p == 0 else "RA") for p in pred]
        return pred


def read_conll(in_file, lowercase=False, max_example=None):
    examples = []
    with open(in_file) as f:
        word, pos, head, label = [], [], [], []
        for line in f.readlines():
            sp = line.strip().split('\t')
            if len(sp) == 10:
                if '-' not in sp[0]:
                    word.append(sp[1].lower() if lowercase else sp[1])
                    pos.append(sp[4])
                    head.append(int(sp[6]))
                    label.append(sp[7])
            elif len(word) > 0:
                examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
                word, pos, head, label = [], [], [], []
                if (max_example is not None) and (len(examples) == max_example):
                    break
        if len(word) > 0:
            examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
    return examples


def build_dict(keys, n_max=None, offset=0):
    count = Counter()
    for key in keys:
        count[key] += 1
    ls = count.most_common() if n_max is None \
        else count.most_common(n_max)

    return {w[0]: index + offset for (index, w) in enumerate(ls)}


def punct(language, pos):
    if language == 'english':
        return pos in ["''", ",", ".", ":", "``", "-LRB-", "-RRB-"]
    elif language == 'chinese':
        return pos == 'PU'
    elif language == 'french':
        return pos == 'PUNC'
    elif language == 'german':
        return pos in ["$.", "$,", "$["]
    elif language == 'spanish':
        # http://nlp.stanford.edu/software/spanish-faq.shtml
        return pos in ["f0", "faa", "fat", "fc", "fd", "fe", "fg", "fh",
                       "fia", "fit", "fp", "fpa", "fpt", "fs", "ft",
                       "fx", "fz"]
    elif language == 'universal':
        return pos == 'PUNCT'
    else:
        raise ValueError('language: %s is not supported.' % language)
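

# A rough sketch of the data layout produced by read_conll above (the sentence
# is illustrative, not taken from the assignment data): each example is a dict
# of parallel lists, with heads given as 1-based token positions and 0 meaning
# the root, e.g.
#
#     {'word':  ['ms.', 'haag', 'plays', 'elianti', '.'],
#      'pos':   ['NNP', 'NNP', 'VBZ', 'NNP', '.'],
#      'head':  [2, 3, 0, 3, 3],
#      'label': ['nn', 'nsubj', 'root', 'dobj', 'punct']}
#
# Parser.vectorize later prepends the ROOT token and maps every field to ids.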


def minibatches(data, batch_size):
    x = np.array([d[0] for d in data])
    y = np.array([d[2] for d in data])
    one_hot = np.zeros((y.size, 3))
    one_hot[np.arange(y.size), y] = 1
    return get_minibatches([x, one_hot], batch_size)


def load_and_preprocess_data(reduced=True):
    config = Config()

    print("Loading data...")
    start = time.time()
    train_set = read_conll(os.path.join(config.data_path, config.train_file),
                           lowercase=config.lowercase)
    dev_set = read_conll(os.path.join(config.data_path, config.dev_file),
                         lowercase=config.lowercase)
    test_set = read_conll(os.path.join(config.data_path, config.test_file),
                          lowercase=config.lowercase)
    if reduced:
        train_set = train_set[:1000]
        dev_set = dev_set[:500]
        test_set = test_set[:500]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Building parser...")
    start = time.time()
    parser = Parser(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Loading pretrained embeddings...")
    start = time.time()
    word_vectors = {}
    for line in open(config.embedding_file).readlines():
        sp = line.strip().split()
        word_vectors[sp[0]] = [float(x) for x in sp[1:]]
    embeddings_matrix = np.asarray(np.random.normal(0, 0.9, (parser.n_tokens, 50)),
                                   dtype='float32')

    for token in parser.tok2id:
        i = parser.tok2id[token]
        if token in word_vectors:
            embeddings_matrix[i] = word_vectors[token]
        elif token.lower() in word_vectors:
            embeddings_matrix[i] = word_vectors[token.lower()]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Vectorizing data...")
    start = time.time()
    train_set = parser.vectorize(train_set)
    dev_set = parser.vectorize(dev_set)
    test_set = parser.vectorize(test_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Preprocessing training data...")
    start = time.time()
    train_examples = parser.create_instances(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    return parser, embeddings_matrix, train_examples, dev_set, test_set


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


if __name__ == '__main__':
    pass
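    # A minimal usage sketch (assumes the CoNLL files and embedding file named
    # in Config exist under ./data, which is not checked here); uncomment to
    # run a quick end-to-end preprocessing pass:
    #
    #     parser, embeddings, train_examples, dev_set, test_set = \
    #         load_and_preprocess_data(reduced=True)
    #     for train_x, train_y in minibatches(train_examples, 1024):
    #         print(train_x.shape, train_y.shape)
    #         break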