This commit is contained in:
chongjiu.jin
2019-11-25 10:52:19 +08:00
parent e6435b5f6c
commit 86eeba3b38
15 changed files with 2312591 additions and 1 deletion

general_utils.py

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 3
general_utils.py: General purpose utilities.
Sahil Chopra <schopra8@stanford.edu>
"""
import sys
import time
import numpy as np
def get_minibatches(data, minibatch_size, shuffle=True):
"""
    Iterates through the provided data one minibatch at a time. You can use this function to
iterate through data in minibatches as follows:
for inputs_minibatch in get_minibatches(inputs, minibatch_size):
...
Or with multiple data sources:
for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
...
Args:
data: there are two possible values:
- a list or numpy array
- a list where each element is either a list or numpy array
minibatch_size: the maximum number of items in a minibatch
shuffle: whether to randomize the order of returned data
    Yields:
        minibatches: the yielded value depends on data:
            - If data is a list/array, yields the next minibatch of data.
            - If data is a list of lists/arrays, yields a list containing the
              next minibatch of each element. This can be used to iterate
              through multiple data sources (e.g., features and labels) at the
              same time.
"""
list_data = type(data) is list and (type(data[0]) is list or type(data[0]) is np.ndarray)
data_size = len(data[0]) if list_data else len(data)
indices = np.arange(data_size)
if shuffle:
np.random.shuffle(indices)
for minibatch_start in np.arange(0, data_size, minibatch_size):
minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
yield [_minibatch(d, minibatch_indices) for d in data] if list_data \
else _minibatch(data, minibatch_indices)
def _minibatch(data, minibatch_idx):
return data[minibatch_idx] if type(data) is np.ndarray else [data[i] for i in minibatch_idx]
def test_all_close(name, actual, expected):
if actual.shape != expected.shape:
raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
.format(name, expected.shape, actual.shape))
if np.amax(np.fabs(actual - expected)) > 1e-6:
raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
else:
print(name, "passed!")
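# Usage sketch (added; not part of the assignment code, and the toy arrays
# and sizes below are illustrative).
if __name__ == "__main__":
    inputs = np.random.rand(10, 4)           # 10 examples, 4 features each
    labels = np.random.randint(0, 2, 10)     # 10 binary labels
    for x_mb, y_mb in get_minibatches([inputs, labels], minibatch_size=4):
        print(x_mb.shape, y_mb.shape)        # (4, 4) and (4,), then a (2,) tail
    # test_all_close raises ValueError unless shapes and values (to 1e-6) match.
    test_all_close("identity check", np.ones(3), np.ones(3))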

parser_utils.py

@@ -0,0 +1,422 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 3
parser_utils.py: Utilities for training the dependency parser.
Sahil Chopra <schopra8@stanford.edu>
"""
import time
import os
import logging
from collections import Counter
from .general_utils import get_minibatches
from parser_transitions import minibatch_parse
from tqdm import tqdm
import torch
import numpy as np
P_PREFIX = '<p>:'
L_PREFIX = '<l>:'
UNK = '<UNK>'
NULL = '<NULL>'
ROOT = '<ROOT>'
class Config(object):
language = 'english'
with_punct = True
unlabeled = True
lowercase = True
use_pos = True
use_dep = True
use_dep = use_dep and (not unlabeled)
data_path = './data'
train_file = 'train.conll'
dev_file = 'dev.conll'
test_file = 'test.conll'
embedding_file = './data/en-cw.txt'
class Parser(object):
"""Contains everything needed for transition-based dependency parsing except for the model"""
def __init__(self, dataset):
        root_labels = [l for ex in dataset
                       for (h, l) in zip(ex['head'], ex['label']) if h == 0]
counter = Counter(root_labels)
if len(counter) > 1:
logging.info('Warning: more than one root label')
logging.info(counter)
self.root_label = counter.most_common()[0][0]
deprel = [self.root_label] + list(set([w for ex in dataset
for w in ex['label']
if w != self.root_label]))
tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}
tok2id[L_PREFIX + NULL] = self.L_NULL = len(tok2id)
config = Config()
self.unlabeled = config.unlabeled
self.with_punct = config.with_punct
self.use_pos = config.use_pos
self.use_dep = config.use_dep
self.language = config.language
if self.unlabeled:
trans = ['L', 'R', 'S']
self.n_deprel = 1
else:
trans = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['S']
self.n_deprel = len(deprel)
self.n_trans = len(trans)
self.tran2id = {t: i for (i, t) in enumerate(trans)}
self.id2tran = {i: t for (i, t) in enumerate(trans)}
# logging.info('Build dictionary for part-of-speech tags.')
tok2id.update(build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
offset=len(tok2id)))
tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)
# logging.info('Build dictionary for words.')
tok2id.update(build_dict([w for ex in dataset for w in ex['word']],
offset=len(tok2id)))
tok2id[UNK] = self.UNK = len(tok2id)
tok2id[NULL] = self.NULL = len(tok2id)
tok2id[ROOT] = self.ROOT = len(tok2id)
self.tok2id = tok2id
self.id2tok = {v: k for (k, v) in tok2id.items()}
self.n_features = 18 + (18 if config.use_pos else 0) + (12 if config.use_dep else 0)
self.n_tokens = len(tok2id)
def vectorize(self, examples):
vec_examples = []
for ex in examples:
word = [self.ROOT] + [self.tok2id[w] if w in self.tok2id
else self.UNK for w in ex['word']]
pos = [self.P_ROOT] + [self.tok2id[P_PREFIX + w] if P_PREFIX + w in self.tok2id
else self.P_UNK for w in ex['pos']]
head = [-1] + ex['head']
label = [-1] + [self.tok2id[L_PREFIX + w] if L_PREFIX + w in self.tok2id
else -1 for w in ex['label']]
vec_examples.append({'word': word, 'pos': pos,
'head': head, 'label': label})
return vec_examples
def extract_features(self, stack, buf, arcs, ex):
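        # Feature layout mirrors Chen & Manning (2014): 18 word ids (top 3 of
        # the stack, first 3 of the buffer, and children of the top two stack
        # items), plus 18 POS ids and 12 arc-label ids when enabled.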
if stack[0] == "ROOT":
stack[0] = 0
def get_lc(k):
return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] < k])
def get_rc(k):
return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] > k],
reverse=True)
p_features = []
l_features = []
features = [self.NULL] * (3 - len(stack)) + [ex['word'][x] for x in stack[-3:]]
features += [ex['word'][x] for x in buf[:3]] + [self.NULL] * (3 - len(buf))
if self.use_pos:
p_features = [self.P_NULL] * (3 - len(stack)) + [ex['pos'][x] for x in stack[-3:]]
p_features += [ex['pos'][x] for x in buf[:3]] + [self.P_NULL] * (3 - len(buf))
for i in range(2):
if i < len(stack):
k = stack[-i-1]
lc = get_lc(k)
rc = get_rc(k)
llc = get_lc(lc[0]) if len(lc) > 0 else []
rrc = get_rc(rc[0]) if len(rc) > 0 else []
features.append(ex['word'][lc[0]] if len(lc) > 0 else self.NULL)
features.append(ex['word'][rc[0]] if len(rc) > 0 else self.NULL)
features.append(ex['word'][lc[1]] if len(lc) > 1 else self.NULL)
features.append(ex['word'][rc[1]] if len(rc) > 1 else self.NULL)
features.append(ex['word'][llc[0]] if len(llc) > 0 else self.NULL)
features.append(ex['word'][rrc[0]] if len(rrc) > 0 else self.NULL)
if self.use_pos:
p_features.append(ex['pos'][lc[0]] if len(lc) > 0 else self.P_NULL)
p_features.append(ex['pos'][rc[0]] if len(rc) > 0 else self.P_NULL)
p_features.append(ex['pos'][lc[1]] if len(lc) > 1 else self.P_NULL)
p_features.append(ex['pos'][rc[1]] if len(rc) > 1 else self.P_NULL)
p_features.append(ex['pos'][llc[0]] if len(llc) > 0 else self.P_NULL)
p_features.append(ex['pos'][rrc[0]] if len(rrc) > 0 else self.P_NULL)
if self.use_dep:
l_features.append(ex['label'][lc[0]] if len(lc) > 0 else self.L_NULL)
l_features.append(ex['label'][rc[0]] if len(rc) > 0 else self.L_NULL)
l_features.append(ex['label'][lc[1]] if len(lc) > 1 else self.L_NULL)
l_features.append(ex['label'][rc[1]] if len(rc) > 1 else self.L_NULL)
l_features.append(ex['label'][llc[0]] if len(llc) > 0 else self.L_NULL)
l_features.append(ex['label'][rrc[0]] if len(rrc) > 0 else self.L_NULL)
else:
features += [self.NULL] * 6
if self.use_pos:
p_features += [self.P_NULL] * 6
if self.use_dep:
l_features += [self.L_NULL] * 6
features += p_features + l_features
assert len(features) == self.n_features
return features
def get_oracle(self, stack, buf, ex):
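        # Transition ids: in the unlabeled setting, 0 = left-arc, 1 = right-arc,
        # 2 = shift. In the labeled setting, ids in [0, n_deprel) are left-arcs,
        # ids in [n_deprel, 2 * n_deprel) are right-arcs, and the last id is
        # shift. Returns None when no gold transition applies.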
if len(stack) < 2:
return self.n_trans - 1
i0 = stack[-1]
i1 = stack[-2]
h0 = ex['head'][i0]
h1 = ex['head'][i1]
l0 = ex['label'][i0]
l1 = ex['label'][i1]
if self.unlabeled:
if (i1 > 0) and (h1 == i0):
return 0
elif (i1 >= 0) and (h0 == i1) and \
(not any([x for x in buf if ex['head'][x] == i0])):
return 1
else:
return None if len(buf) == 0 else 2
else:
if (i1 > 0) and (h1 == i0):
return l1 if (l1 >= 0) and (l1 < self.n_deprel) else None
elif (i1 >= 0) and (h0 == i1) and \
(not any([x for x in buf if ex['head'][x] == i0])):
return l0 + self.n_deprel if (l0 >= 0) and (l0 < self.n_deprel) else None
else:
return None if len(buf) == 0 else self.n_trans - 1
def create_instances(self, examples):
all_instances = []
succ = 0
        for ex in examples:
n_words = len(ex['word']) - 1
# arcs = {(h, t, label)}
stack = [0]
buf = [i + 1 for i in range(n_words)]
arcs = []
instances = []
for i in range(n_words * 2):
gold_t = self.get_oracle(stack, buf, ex)
if gold_t is None:
break
legal_labels = self.legal_labels(stack, buf)
assert legal_labels[gold_t] == 1
instances.append((self.extract_features(stack, buf, arcs, ex),
legal_labels, gold_t))
if gold_t == self.n_trans - 1:
stack.append(buf[0])
buf = buf[1:]
elif gold_t < self.n_deprel:
arcs.append((stack[-1], stack[-2], gold_t))
stack = stack[:-2] + [stack[-1]]
else:
arcs.append((stack[-2], stack[-1], gold_t - self.n_deprel))
stack = stack[:-1]
else:
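                # for-else: reached only when the oracle succeeded on every
                # step (the loop never hit break), so count a successful parse.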
succ += 1
all_instances += instances
return all_instances
def legal_labels(self, stack, buf):
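        # 0/1 mask over transitions, in the same order as the transition ids:
        # left-arcs need more than two items on the stack (ROOT may not get a
        # head), right-arcs need at least two, and shift needs a non-empty buffer.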
labels = ([1] if len(stack) > 2 else [0]) * self.n_deprel
labels += ([1] if len(stack) >= 2 else [0]) * self.n_deprel
labels += [1] if len(buf) > 0 else [0]
return labels
def parse(self, dataset, eval_batch_size=5000):
sentences = []
sentence_id_to_idx = {}
for i, example in enumerate(dataset):
n_words = len(example['word']) - 1
sentence = [j + 1 for j in range(n_words)]
sentences.append(sentence)
sentence_id_to_idx[id(sentence)] = i
model = ModelWrapper(self, dataset, sentence_id_to_idx)
dependencies = minibatch_parse(sentences, model, eval_batch_size)
UAS = all_tokens = 0.0
with tqdm(total=len(dataset)) as prog:
for i, ex in enumerate(dataset):
head = [-1] * len(ex['word'])
                for h, t in dependencies[i]:
head[t] = h
for pred_h, gold_h, gold_l, pos in \
zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
assert self.id2tok[pos].startswith(P_PREFIX)
pos_str = self.id2tok[pos][len(P_PREFIX):]
if (self.with_punct) or (not punct(self.language, pos_str)):
UAS += 1 if pred_h == gold_h else 0
all_tokens += 1
                prog.update(1)  # update() is cumulative; advance one sentence
UAS /= all_tokens
return UAS, dependencies
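# A small smoke test for the oracle (added sketch, not part of the original
# assignment; the two-word sentence and the _demo_oracle name are illustrative).
def _demo_oracle():
    # Toy example in read_conll's output format: "i eat", with 'eat' as root.
    toy = [{'word': ['i', 'eat'], 'pos': ['PRP', 'VBP'],
            'head': [2, 0], 'label': ['nsubj', 'root']}]
    parser = Parser(toy)
    ex = parser.vectorize(toy)[0]
    stack, buf, n_words = [0], [1, 2], 2   # 0 is ROOT; 1, 2 are 'i', 'eat'
    moves = []
    for _ in range(2 * n_words):           # a sentence takes exactly 2n moves
        t = parser.get_oracle(stack, buf, ex)
        moves.append(parser.id2tran[t])
        if t == parser.n_trans - 1:        # shift
            stack.append(buf.pop(0))
        elif t < parser.n_deprel:          # left-arc: drop second-from-top
            stack.pop(-2)
        else:                              # right-arc: drop top
            stack.pop()
    print(moves)                           # expected: ['S', 'S', 'L', 'R']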
class ModelWrapper(object):
def __init__(self, parser, dataset, sentence_id_to_idx):
self.parser = parser
self.dataset = dataset
self.sentence_id_to_idx = sentence_id_to_idx
def predict(self, partial_parses):
mb_x = [self.parser.extract_features(p.stack, p.buffer, p.dependencies,
self.dataset[self.sentence_id_to_idx[id(p.sentence)]])
for p in partial_parses]
mb_x = np.array(mb_x).astype('int32')
mb_x = torch.from_numpy(mb_x).long()
mb_l = [self.parser.legal_labels(p.stack, p.buffer) for p in partial_parses]
pred = self.parser.model(mb_x)
pred = pred.detach().numpy()
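        # Adding a large bonus to the legal entries before argmax guarantees
        # that only legal transitions can be predicted.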
pred = np.argmax(pred + 10000 * np.array(mb_l).astype('float32'), 1)
pred = ["S" if p == 2 else ("LA" if p == 0 else "RA") for p in pred]
return pred
def read_conll(in_file, lowercase=False, max_example=None):
examples = []
with open(in_file) as f:
word, pos, head, label = [], [], [], []
        for line in f:
sp = line.strip().split('\t')
if len(sp) == 10:
if '-' not in sp[0]:
word.append(sp[1].lower() if lowercase else sp[1])
pos.append(sp[4])
head.append(int(sp[6]))
label.append(sp[7])
elif len(word) > 0:
examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
word, pos, head, label = [], [], [], []
if (max_example is not None) and (len(examples) == max_example):
break
if len(word) > 0:
examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
return examples
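# Format sketch for read_conll (added; the two-line sample and the
# _demo_read_conll name are hypothetical). Each token is a tab-separated
# 10-column CoNLL-X line: ID, FORM, LEMMA, CPOSTAG, POSTAG, FEATS, HEAD,
# DEPREL, PHEAD, PDEPREL; a blank line ends a sentence. Only FORM, POSTAG,
# HEAD and DEPREL are used above.
def _demo_read_conll():
    import tempfile
    lines = ["1\tI\t_\tPRP\tPRP\t_\t2\tnsubj\t_\t_",
             "2\teat\t_\tVBP\tVBP\t_\t0\troot\t_\t_",
             ""]
    with tempfile.NamedTemporaryFile('w', suffix='.conll', delete=False) as f:
        f.write("\n".join(lines))
    print(read_conll(f.name, lowercase=True))
    # -> [{'word': ['i', 'eat'], 'pos': ['PRP', 'VBP'],
    #      'head': [2, 0], 'label': ['nsubj', 'root']}]
    os.remove(f.name)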
def build_dict(keys, n_max=None, offset=0):
count = Counter()
for key in keys:
count[key] += 1
ls = count.most_common() if n_max is None \
else count.most_common(n_max)
return {w[0]: index + offset for (index, w) in enumerate(ls)}
def punct(language, pos):
if language == 'english':
return pos in ["''", ",", ".", ":", "``", "-LRB-", "-RRB-"]
elif language == 'chinese':
return pos == 'PU'
elif language == 'french':
return pos == 'PUNC'
elif language == 'german':
return pos in ["$.", "$,", "$["]
elif language == 'spanish':
# http://nlp.stanford.edu/software/spanish-faq.shtml
return pos in ["f0", "faa", "fat", "fc", "fd", "fe", "fg", "fh",
"fia", "fit", "fp", "fpa", "fpt", "fs", "ft",
"fx", "fz"]
elif language == 'universal':
return pos == 'PUNCT'
else:
raise ValueError('language: %s is not supported.' % language)
def minibatches(data, batch_size):
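    # data is a list of (features, legal_labels, gold_transition) triples from
    # create_instances; the hard-coded one-hot width of 3 assumes the unlabeled
    # setting (left-arc, right-arc, shift).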
x = np.array([d[0] for d in data])
y = np.array([d[2] for d in data])
one_hot = np.zeros((y.size, 3))
one_hot[np.arange(y.size), y] = 1
return get_minibatches([x, one_hot], batch_size)
def load_and_preprocess_data(reduced=True):
config = Config()
print("Loading data...",)
start = time.time()
train_set = read_conll(os.path.join(config.data_path, config.train_file),
lowercase=config.lowercase)
dev_set = read_conll(os.path.join(config.data_path, config.dev_file),
lowercase=config.lowercase)
test_set = read_conll(os.path.join(config.data_path, config.test_file),
lowercase=config.lowercase)
if reduced:
train_set = train_set[:1000]
dev_set = dev_set[:500]
test_set = test_set[:500]
print("took {:.2f} seconds".format(time.time() - start))
print("Building parser...",)
start = time.time()
parser = Parser(train_set)
print("took {:.2f} seconds".format(time.time() - start))
print("Loading pretrained embeddings...",)
start = time.time()
    word_vectors = {}
    with open(config.embedding_file) as f:
        for line in f:
            sp = line.strip().split()
            word_vectors[sp[0]] = [float(x) for x in sp[1:]]
embeddings_matrix = np.asarray(np.random.normal(0, 0.9, (parser.n_tokens, 50)), dtype='float32')
for token in parser.tok2id:
i = parser.tok2id[token]
if token in word_vectors:
embeddings_matrix[i] = word_vectors[token]
elif token.lower() in word_vectors:
embeddings_matrix[i] = word_vectors[token.lower()]
print("took {:.2f} seconds".format(time.time() - start))
print("Vectorizing data...",)
start = time.time()
train_set = parser.vectorize(train_set)
dev_set = parser.vectorize(dev_set)
test_set = parser.vectorize(test_set)
print("took {:.2f} seconds".format(time.time() - start))
print("Preprocessing training data...",)
start = time.time()
train_examples = parser.create_instances(train_set)
print("took {:.2f} seconds".format(time.time() - start))
    return parser, embeddings_matrix, train_examples, dev_set, test_set
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if __name__ == '__main__':
pass
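    # Usage sketch for AverageMeter (added; the loss and batch-size values
    # below are illustrative).
    meter = AverageMeter()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        meter.update(batch_loss, n=batch_size)
    print("last: {:.2f} avg: {:.2f}".format(meter.val, meter.avg))  # 0.50 / 0.74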