# cs224n_2019/rnn-poetry/run.py

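"""Character-level RNN language model for classical Chinese poetry.

Reads one poem per line from data/poetry.txt, wraps each poem in '[' / ']'
start/end markers, builds a character vocabulary, and trains an RNN to
predict the next character. After every epoch the weights are saved to
model.bin and a sample poem is generated.
"""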
import re
import collections
import pickle

import torch
import tqdm
from torch import nn

from model import RNNModel, embed_size, hidden_dims

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 100
lr = 0.001
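
# `RNNModel`, `embed_size`, and `hidden_dims` come from model.py, which is not
# shown in this file. Judging from how the model is called below, its assumed
# interface is roughly (a hypothetical sketch, not the actual implementation):
#
#     class RNNModel(nn.Module):
#         def forward(self, input, hidden=None):
#             # input:  LongTensor of shape (seq_len, batch)
#             # return: (logits of shape (seq_len * batch, vocab_size), hidden)
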
def get_data():
    # keep only word characters, '。', and spaces
    special_character_removal = re.compile(r'[^\w。 ]', re.IGNORECASE)
    poetrys = []
    poetry_path = 'data/poetry.txt'
    with open(poetry_path, 'r', encoding='utf-8') as f:
        for content in f:
            content = special_character_removal.sub('', content.strip())
            # wrap each poem in start/end markers
            content = '[' + content + ']'
            poetrys.append(content)
    # count how often each character occurs
    all_words = []
    for poetry in poetrys:
        all_words.extend(poetry)
    counter = collections.Counter(all_words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    words, _ = zip(*count_pairs)
    # vocabulary ordered by descending frequency, plus a space for padding
    words = words[:len(words)] + (' ',)
    # map each character to an integer ID and back
    word2ix = dict(zip(words, range(len(words))))
    ix2word = {v: k for k, v in word2ix.items()}
    data = [[word2ix[c] for c in poetry] for poetry in poetrys]
    return data, word2ix, ix2word
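
# Illustration of what get_data() returns for a hypothetical one-poem corpus
# (all counts tie, and the stable sort keeps first-seen order, so the IDs
# happen to follow reading order here):
#
#     poetrys -> ['[床前明月光。]']
#     word2ix -> {'[': 0, '床': 1, '前': 2, '明': 3, '月': 4, '光': 5, '。': 6, ']': 7}
#     data    -> [[0, 1, 2, 3, 4, 5, 6, 7]]
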
def test(model):
    """Generate one poem greedily, starting from the '[' marker."""
    start_idx = [word2ix['[']]
    end_word = ''
    hidden = None
    ret = ''
    with torch.no_grad():
        # feed the last generated character back in until ']' is produced
        while end_word != ']' and len(ret) < 100:
            data_ = torch.tensor([start_idx], device=device).long()
            output, hidden = model(data_, hidden)
            # greedy decoding: take the most likely next character
            output_idx = output.view(-1).argmax().item()
            start_idx = [output_idx]
            end_word = ix2word[output_idx]
            ret += end_word
    return ret
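
# Note: argmax decoding is deterministic, so test() returns the same poem for a
# given set of weights; sampling from the softmax distribution instead would
# produce varied generations.
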
def train():
    # build the model
    model = RNNModel(len(word2ix), embed_size, hidden_dims)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        count = 0
        # one poem per optimizer step, i.e. batch size 1
        for data_ in tqdm.tqdm(data):
            data_ = torch.tensor(data_).long()
            x = data_.unsqueeze(1).to(device)  # shape: (seq_len, 1)
            optimizer.zero_grad()
            # target = input shifted left by one character; the final ']'
            # wraps around to predict the leading '['
            y = torch.zeros(x.shape).to(device).long()
            y[:-1], y[-1] = x[1:], x[0]
            output, _ = model(x)
            loss = criterion(output, y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            count += 1
        print(epoch, 'loss=', total_loss / count)
        torch.save(model.state_dict(), 'model.bin')
        # generate a sample poem after each epoch
        chars = test(model)
        print(chars)
if __name__ == "__main__":
    # load the data and persist the vocabulary mappings for reuse
    data, word2ix, ix2word = get_data()
    with open("word2ix.pkl", 'wb') as outfile:
        pickle.dump(word2ix, outfile)
    with open("ix2word.pkl", 'wb') as outfile:
        pickle.dump(ix2word, outfile)
    # limit training to the first 100 poems
    data = data[:100]
    train()
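
# Loading the checkpoint later for generation only (a minimal sketch; assumes
# the files written by this script already exist):
#
#     with open("word2ix.pkl", 'rb') as f:
#         word2ix = pickle.load(f)
#     with open("ix2word.pkl", 'rb') as f:
#         ix2word = pickle.load(f)
#     model = RNNModel(len(word2ix), embed_size, hidden_dims)
#     model.load_state_dict(torch.load('model.bin', map_location=device))
#     model.to(device)
#     print(test(model))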