cs224n_2019/rnn-poetry/run.py

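"""Character-level RNN poetry generator.

Reads poems from data/poetry.txt, wraps each one in '[' ... ']' start/end
markers, builds a character-level vocabulary, and trains an RNN language
model to predict the next character at every position. After each epoch the
model is checkpointed and a sample poem is printed.
"""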
import re
import tqdm
import torch
import collections
import pickle
from torch import nn
from model import RNNModel, embed_size, hidden_dims, batch_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 100
lr = 0.001

def get_data():
    """Read the corpus; return (encoded poems, word2ix, ix2word)."""
    # compiled but never applied below; kept from the original code
    special_character_removal = re.compile(r'[^\w。 ]', re.IGNORECASE)
    # poem collection
    poetrys = []
    poetry_path = 'data/poetry.txt'
    with open(poetry_path, 'r', encoding='utf-8') as f:
        for content in f:
            content = content.strip()
            # '[' and ']' mark the start and end of each poem
            content = '[' + content + ']'
            poetrys.append(content)
    # poetrys = sorted(poetrys, key=lambda line: len(line))

    # count how often each character occurs
    all_words = []
    for poetry in poetrys:
        all_words += [word for word in poetry]
    counter = collections.Counter(all_words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    words, _ = zip(*count_pairs)
    # keep the most frequent characters (the slice currently keeps all of
    # them) and append ' ' as a padding character
    words = words[:len(words)] + (' ',)
    # map each character to an integer id, and back
    word2ix = dict(zip(words, range(len(words))))
    ix2word = {v: k for k, v in word2ix.items()}
    data = [[word2ix[c] for c in poetry] for poetry in poetrys]
    return data, word2ix, ix2word

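# Note: test() below decodes greedily (argmax at every step), so a trained
# model always emits the same poem. A common variant, sketched here as a
# hypothetical replacement for the argmax line, samples from the softmax
# distribution instead:
#   probs = torch.softmax(output.view(-1), dim=0)
#   output_idx = torch.multinomial(probs, 1).item()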
def test(model):
    """Generate one poem by feeding the model its own predictions."""
    start_idx = [word2ix['[']]  # start from the start-of-poem marker
    end_word = ''
    hidden = None
    ret = ''
    # keep predicting until the end marker ']' appears or 100 characters
    while end_word != ']' and len(ret) < 100:
        data_ = torch.tensor([start_idx], device=device).long()
        output, hidden = model(data_, hidden)
        # greedy decoding: pick the most likely next character
        output_idx = output.view(-1).argmax().cpu().item()
        start_idx = [output_idx]
        end_word = ix2word[output_idx]
        ret += end_word
    return ret


def train():
    # model definition
    model = RNNModel(len(word2ix), embed_size, hidden_dims)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        count = 0
        # one poem per optimization step: x has shape (seq_len, 1)
        for ii, data_ in tqdm.tqdm(enumerate(data)):
            data_ = torch.tensor(data_).long()
            x = data_.unsqueeze(1).to(device)
            optimizer.zero_grad()
            # target is the input shifted left by one character; the final
            # target wraps around to the first character
            y = torch.zeros_like(x)
            y[:-1], y[-1] = x[1:], x[0]
            output, _ = model(x)
            loss = criterion(output, y.view(-1))
            # Abandoned variant kept from the original (max_lenth was never
            # defined): train on progressively longer prefixes of the poem.
            # hidden = None
            # for k in range(2, max_lenth):
            #     data1 = data_[:k]
            #     input_, target = data1[:-1, :], data1[1:, :]
            #     output, hidden = model(input_, hidden)
            #     loss = criterion(output, target.view(-1))
            #     optimizer.step()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            count += 1
        print(epoch, 'loss=', total_loss / count)
        torch.save(model.state_dict(), 'model.bin')
        chars = test(model)
        print(chars)

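# Note: train() optimizes one poem at a time (unsqueeze(1) gives a batch
# dimension of 1), so the batch_size imported from model is unused in this
# file. Padding poems to a common length and stacking them into real batches
# would be the usual way to speed this up.
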
if __name__ == "__main__":
    # load the data and persist the vocabulary mappings for later generation
    data, word2ix, ix2word = get_data()
    with open("word2ix.pkl", 'wb') as outfile:
        pickle.dump(word2ix, outfile)
    with open("ix2word.pkl", 'wb') as outfile:
        pickle.dump(ix2word, outfile)

    # keep only the first 100 poems (presumably a quick experiment; remove
    # this line to train on the full corpus)
    data = data[:100]
    train()
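
# Minimal reload-for-generation sketch (hypothetical usage, assuming the same
# model module and the globals set above, which test() relies on):
#   model = RNNModel(len(word2ix), embed_size, hidden_dims)
#   model.load_state_dict(torch.load('model.bin', map_location=device))
#   model.to(device)
#   print(test(model))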