add rnn nlg

This commit is contained in:
chongjiu.jin
2019-12-13 14:39:11 +08:00
parent 8baa4f6bf9
commit 704f59eaed
3 changed files with 113 additions and 48 deletions

11
rnn-poetry/flask_demo.py Normal file
View File

@@ -0,0 +1,11 @@
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
if __name__ == '__main__':
app.run()

View File

@@ -2,43 +2,28 @@ import re
import tqdm
import torch
import collections
import numpy
import pickle
from torch import nn
from model import RNNModel
from model import RNNModel,embed_size,hidden_dims,batch_size
device=
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size=128
embed_size=256
epochs=100
lr=0.001
def get_data():
poetry_file = 'data/poetry.txt'
special_character_removal = re.compile(r'[^\w。 ]', re.IGNORECASE)
# 诗集
poetrys = []
with open(poetry_file, "r", encoding='utf-8', ) as f:
for line in f:
try:
title, content = line.strip().split(':')
content = special_character_removal.sub('', content)
content = content.replace(' ', '')
if len(content) < 5:
continue
if (len(content) > 12 * 6):
content_list = content.split("")
for i in range(0, len(content_list), 2):
content_temp = '[' + content_list[i] + "" + content_list[i + 1] + '。]'
content_temp = content_temp.replace("。。", "")
poetrys.append(content_temp)
else:
content = '[' + content + ']'
poetrys.append(content)
except Exception as e:
pass
poetrys = sorted(poetrys, key=lambda line: len(line))
peotry_path='data/poetry.txt'
with open(peotry_path,'r',encoding='utf-8') as f:
for content in f:
content=content.strip()
content = '[' + content + ']'
poetrys.append(content)
# poetrys = sorted(poetrys, key=lambda line: len(line))
# 统计每个字出现次数
all_words = []
for poetry in poetrys:
@@ -50,41 +35,79 @@ def get_data():
words = words[:len(words)] + (' ',)
# 每个字映射为一个数字ID
word2ix = dict(zip(words, range(len(words))))
ix2word = lambda word: word2ix.get(word, len(words))
data = [list(map(ix2word, poetry)) for poetry in poetrys]
data=numpy.array(data)
ix2word = {v: k for k, v in word2ix.items()}
data = [[word2ix[c] for c in poetry ] for poetry in poetrys]
# data=numpy.array(data)
return data,word2ix,ix2word
def test(model):
start_idx=[word2ix['[']]
end_word=''
lens=0
hidden = None
ret=''
while end_word!=']' and len(ret)<100:
data_ = torch.tensor([start_idx],device=device).long()
# print("data size",data_.size())
output, hidden = model(data_, hidden)
# print("output size", output.size())
ouput_idx=output.view(-1).argmax().cpu()
# print('ouput_idx',ouput_idx)
# print('ouput_idx', ouput_idx.item())
ouput_idx=ouput_idx.item()
start_idx=[ouput_idx]
end_word=ix2word[ouput_idx]
ret+=end_word
return ret
def train():
# 获取数据
data, word2ix, ix2word = get_data()
data = torch.from_numpy(data)
dataloader = torch.utils.data.DataLoader(data,
batch_size=batch_size,
shuffle=True,
num_workers=1)
# 模型定义
model = RNNModel(len(word2ix), batch_size, embed_size)
model = RNNModel(len(word2ix), embed_size, hidden_dims)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
model.to(device)
for epoch in range(epochs):
for ii, data_ in tqdm(enumerate(dataloader)):
data_ = data_.long().transpose(1, 0).contiguous()
data_ = data_.to(device)
model.train()
for epoch in (range(epochs)):
total_loss=0
count=0
for ii, data_ in tqdm.tqdm(enumerate(data)):
data_=torch.tensor(data_).long()
x = data_.unsqueeze(1).to(device)
optimizer.zero_grad()
input_, target = data_[:-1, :], data_[1:, :]
output, _ = model(input_)
loss = criterion(output, target.view(-1))
y = torch.zeros(x.shape).to(device).long()
y[:-1], y[-1] = x[1:], x[0]
output, _ = model(x)
loss = criterion(output, y.view(-1))
"""
hidden=None
for k in range(2,max_lenth):
data1=data_[:k]
input_, target = data1[:-1, :], data1[1:, :]
output, hidden = model(input_,hidden)
loss = criterion(output, target.view(-1))
optimizer.step()
"""
loss.backward()
optimizer.step()
total_loss+=(loss.item())
count+=1
print(epoch,'loss=',total_loss/count)
torch.save(model.state_dict(), 'model.bin' )
chars=test(model)
print(chars)
torch.save(model.state_dict(), 'model.bin' )
if __name__ == "__main__":
# 获取数据
data, word2ix, ix2word = get_data()
with open("word2ix.pkl", 'wb') as outfile:
pickle.dump(word2ix,outfile)
with open("ix2word.pkl", 'wb') as outfile:
pickle.dump(ix2word,outfile)
data=data[:100]
train()

View File

@@ -0,0 +1,31 @@
import pickle
from flask import Flask,request
app = Flask(__name__)
import torch
from model import RNNModel,embed_size,hidden_dims,generate_poetry
with open("word2ix.pkl", 'rb') as outfile:
word2ix=pickle.load(outfile)
with open("ix2word.pkl", 'rb') as outfile:
ix2word=pickle.load(outfile)
device=torch.device('cpu')
model = RNNModel(len(word2ix), embed_size, hidden_dims)
init_checkpoint = 'model.bin'
model.load_state_dict(torch.load(init_checkpoint, map_location='cpu'))
@app.route('/')
def hello():
return 'Hello World!'
@app.route('/peom')
def predict():
begin_word = request.args.get('text', '')
ret=generate_poetry(model,word2ix,ix2word,device,begin_word)
return ret
if __name__ == '__main__':
app.run()