a2
[finished]Assignment_2_word2vec/others/pytorch_train.py (new file, 105 lines added)
@@ -0,0 +1,105 @@
import random
import re

import torch
import torch.optim as optim
from tqdm import tqdm

from pytorch_word2vec_model import SkipGram

epochs = 50
negative_sampling = 4   # negative samples drawn per center word
window = 2              # context words taken on each side of the center word
vocab_size = 1          # placeholder; overwritten after the vocabulary is built below
embd_size = 300


def batch_data(x, batch_size=128):
    """Yield [center_words, context_words, targets] batches for skip-gram
    training with negative sampling."""
    in_w = []
    out_w = []
    target = []
    for text in x:
        for i in range(window, len(text) - window):
            # true context ids for this center word; negatives are not drawn from this set
            word_set = {text[i - 2], text[i - 1], text[i], text[i + 1], text[i + 2]}

            # one positive pair per context position (window = 2 on each side)
            in_w.append(text[i])
            in_w.append(text[i])
            in_w.append(text[i])
            in_w.append(text[i])

            out_w.append(text[i - 2])
            out_w.append(text[i - 1])
            out_w.append(text[i + 1])
            out_w.append(text[i + 2])

            target.append(1)
            target.append(1)
            target.append(1)
            target.append(1)

            # negative sampling: random word ids that are not true context words, labeled 0
            count = 0
            while count < negative_sampling:
                rand_id = random.randint(0, vocab_size - 1)
                if rand_id not in word_set:
                    in_w.append(text[i])
                    out_w.append(rand_id)
                    target.append(0)
                    count += 1

            if len(out_w) >= batch_size:
                yield [in_w, out_w, target]
                in_w = []
                out_w = []
                target = []
    if out_w:
        yield [in_w, out_w, target]
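# Note (illustration added for clarity, not part of the original logic): each yielded
# batch is three parallel lists -- center-word ids, context/sampled-word ids, and 0/1
# targets -- so every center word contributes 2 * window = 4 positive pairs plus
# negative_sampling = 4 negative pairs. With hypothetical ids, a center word 7 whose
# context is [3, 5, 9, 2] adds in_w = [7] * 8, out_w = [3, 5, 9, 2, r1, r2, r3, r4],
# target = [1, 1, 1, 1, 0, 0, 0, 0], where r1..r4 are random non-context ids.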


def train(train_text_id, model, opt):
    model.train()  # enable dropout and batch normalization
    ave_loss = 0
    pbar = tqdm()
    cnt = 0
    for x_batch in batch_data(train_text_id):
        in_w, out_w, target = x_batch
        in_w_var = torch.tensor(in_w)
        out_w_var = torch.tensor(out_w)
        target_var = torch.tensor(target, dtype=torch.float)

        model.zero_grad()  # clear gradients from the previous batch
        log_probs = model(in_w_var, out_w_var)
        loss = model.loss(log_probs, target_var)
        loss.backward()
        opt.step()
        ave_loss += loss.item()
        pbar.update(1)
        cnt += 1
        pbar.set_description('< loss: %.5f >' % (ave_loss / cnt))
    pbar.close()

text_id = []     # corpus as lists of word ids, one list per line
vocab_dict = {}  # word -> integer id

with open(
        'D:\\project\\ml\\github\\cs224n-natural-language-processing-winter2019\\a1_intro_word_vectors\\a1\\corpus\\corpus.txt',
        encoding='utf-8') as fp:
    for line in fp:
        lines = re.sub("[^A-Za-z0-9']+", ' ', line).lower().split()
        line_id = []
        for s in lines:
            if not s:
                continue
            if s not in vocab_dict:
                vocab_dict[s] = len(vocab_dict)
            word_id = vocab_dict[s]
            line_id.append(word_id)
            if word_id == 11500:  # leftover debug check for one specific id
                print(word_id, s)
        text_id.append(line_id)

vocab_size = len(vocab_dict)  # replaces the placeholder defined at the top
print('vocab_size', vocab_size)
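# Illustration (hypothetical values, not from the actual corpus): if the file began
# with the line "the quick brown fox", vocab_dict would start as
# {'the': 0, 'quick': 1, 'brown': 2, 'fox': 3} and text_id as [[0, 1, 2, 3], ...];
# batch_data() consumes exactly this list-of-id-lists structure.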
model = SkipGram(vocab_size, embd_size)

# create the optimizer once so Adam's running moment estimates persist across epochs
opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                 lr=0.001, weight_decay=0)

for epoch in range(epochs):
    print('epoch', epoch)
    train(text_id, model, opt)
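For reference, pytorch_word2vec_model.SkipGram is not part of this commit; the sketch below is only a guess at a minimal implementation that would be compatible with how the script calls it (model(in_w, out_w) returning one score per pair, and model.loss(scores, targets)), not the author's actual model.

import torch
import torch.nn as nn


class SkipGram(nn.Module):
    """Hypothetical skip-gram-with-negative-sampling model matching the calls above."""

    def __init__(self, vocab_size, embd_size):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embd_size)   # center-word vectors
        self.out_embed = nn.Embedding(vocab_size, embd_size)  # context-word vectors
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, in_w, out_w):
        # dot product between center and context embeddings -> one logit per pair
        v = self.in_embed(in_w)    # (batch, embd_size)
        u = self.out_embed(out_w)  # (batch, embd_size)
        return (v * u).sum(dim=1)  # (batch,)

    def loss(self, scores, target):
        # binary cross-entropy: real pairs are labeled 1, sampled pairs 0
        return self.bce(scores, target)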