import random
import re

import torch
import torch.optim as optim
from tqdm import tqdm

from pytorch_word2vec_model import SkipGram
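

# `pytorch_word2vec_model` is not included alongside this script. Judging by
# how SkipGram is used below -- model(in_w, out_w) returns one score per
# (center, context) pair and model.loss(scores, target) reduces them against
# 0/1 labels -- its interface plausibly resembles this sketch. The class body
# here is an assumption for illustration, not the original model.
import torch.nn as nn


class SkipGramSketch(nn.Module):
    def __init__(self, vocab_size, embd_size):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embd_size)   # center-word vectors
        self.out_embed = nn.Embedding(vocab_size, embd_size)  # context-word vectors
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, in_w, out_w):
        # dot product of center and context embeddings -> one logit per pair
        return (self.in_embed(in_w) * self.out_embed(out_w)).sum(dim=1)

    def loss(self, scores, target):
        # binary cross-entropy over positive (1) and negative (0) pairs
        return self.bce(scores, target)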


epochs = 50
negative_sampling = 4  # negatives drawn per positive pair
window = 2             # context window radius
vocab_size = 1         # placeholder; reset once the vocabulary is built
embd_size = 300


def batch_data(x, batch_size=128):
    """Yield [in_w, out_w, target] batches of skip-gram pairs with negatives."""
    in_w = []
    out_w = []
    target = []
    for text in x:
        for i in range(window, len(text) - window):
            # positive (center, context) pairs from the +/-window context
            context_ids = [text[i + off] for off in range(-window, window + 1) if off != 0]
            for ctx in context_ids:
                in_w.append(text[i])
                out_w.append(ctx)
                target.append(1)

            # negative sampling: draw random ids that are neither the center
            # word nor one of its context words
            word_set = set(context_ids)
            word_set.add(text[i])
            count = 0
            while count < negative_sampling:
                rand_id = random.randint(0, vocab_size - 1)
                if rand_id not in word_set:
                    in_w.append(text[i])
                    out_w.append(rand_id)
                    target.append(0)
                    count += 1

            if len(out_w) >= batch_size:
                yield [in_w, out_w, target]
                in_w = []
                out_w = []
                target = []
    if out_w:
        yield [in_w, out_w, target]
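
# A quick illustration of what batch_data yields (kept as a comment: calling
# it before vocab_size is set would make negative sampling spin forever on
# the placeholder vocabulary):
#
#   >>> vocab_size = 10
#   >>> next(batch_data([[0, 1, 2, 3, 4]], batch_size=8))
#   [[2, 2, 2, 2, 2, 2, 2, 2], [0, 1, 3, 4, n1, n2, n3, n4], [1, 1, 1, 1, 0, 0, 0, 0]]
#
# where n1..n4 are random ids outside {0, 1, 2, 3, 4}: each center word
# contributes 2*window positives followed by `negative_sampling` negatives.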


def train(train_text_id, model, opt):
    model.train()  # enable dropout and batch normalization
    ave_loss = 0
    pbar = tqdm()
    cnt = 0
    for x_batch in batch_data(train_text_id):
        in_w, out_w, target = x_batch
        in_w_var = torch.tensor(in_w)
        out_w_var = torch.tensor(out_w)
        target_var = torch.tensor(target, dtype=torch.float)

        model.zero_grad()
        log_probs = model(in_w_var, out_w_var)
        loss = model.loss(log_probs, target_var)
        loss.backward()
        opt.step()

        ave_loss += loss.item()
        pbar.update(1)
        cnt += 1
        pbar.set_description('< loss: %.5f >' % (ave_loss / cnt))
    pbar.close()


text_id = []
vocab_dict = {}

with open(
        'D:\\project\\ml\\github\\cs224n-natural-language-processing-winter2019\\a1_intro_word_vectors\\a1\\corpus\\corpus.txt',
        encoding='utf-8') as fp:
    for line in fp:
        # keep letters, digits and apostrophes, then lowercase and tokenize
        words = re.sub("[^A-Za-z0-9']+", ' ', line).lower().split()
        line_id = []
        for s in words:
            if s not in vocab_dict:
                vocab_dict[s] = len(vocab_dict)
            line_id.append(vocab_dict[s])
        text_id.append(line_id)

vocab_size = len(vocab_dict)
print('vocab_size', vocab_size)
model = SkipGram(vocab_size, embd_size)

# build the optimizer once so Adam's moment estimates persist across epochs
# (recreating it inside the loop would silently reset them every epoch)
opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                 lr=0.001, weight_decay=0)
for epoch in range(epochs):
    print('epoch', epoch)
    train(text_id, model, opt)
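

# Once training finishes, nearest neighbours in embedding space give a quick
# sanity check. This sketch assumes the model stores its center-word matrix
# as `in_embed` (as in the SkipGramSketch above); adjust the attribute name
# to whatever the real SkipGram class uses.
id_to_word = {i: w for w, i in vocab_dict.items()}


def nearest(word, k=5):
    emb = model.in_embed.weight.detach()                  # (vocab_size, embd_size)
    v = emb[vocab_dict[word]].unsqueeze(0)                # (1, embd_size)
    sims = torch.nn.functional.cosine_similarity(emb, v)  # (vocab_size,)
    top = sims.topk(k + 1).indices.tolist()               # k+1: the word itself ranks first
    return [id_to_word[i] for i in top if i != vocab_dict[word]][:k]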