add rnn nlg

chongjiu.jin
2019-12-13 14:39:11 +08:00
parent 8baa4f6bf9
commit 704f59eaed
3 changed files with 113 additions and 48 deletions

rnn-poetry/flask_demo.py (new file)

@@ -0,0 +1,11 @@
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello():
    return 'Hello World!'

if __name__ == '__main__':
    app.run()
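A quick way to smoke-test this stub from a second shell, assuming it is running locally on Flask's default port 5000 (the requests package is an extra dependency, used here only for illustration):

import requests

resp = requests.get('http://127.0.0.1:5000/')
print(resp.status_code, resp.text)  # expected: 200 Hello World!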


@@ -2,43 +2,28 @@ import re
 import tqdm
 import torch
 import collections
-import numpy
+import pickle
 from torch import nn
-from model import RNNModel
+from model import RNNModel, embed_size, hidden_dims, batch_size
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-batch_size = 128
-embed_size = 256
 epochs = 100
 lr = 0.001
 def get_data():
-    poetry_file = 'data/poetry.txt'
     special_character_removal = re.compile(r'[^\w。 ]', re.IGNORECASE)
     # poetry collection
     poetrys = []
-    with open(poetry_file, "r", encoding='utf-8') as f:
-        for line in f:
-            try:
-                title, content = line.strip().split(':')
-                content = special_character_removal.sub('', content)
-                content = content.replace(' ', '')
-                if len(content) < 5:
-                    continue
-                if len(content) > 12 * 6:
-                    content_list = content.split("。")
-                    for i in range(0, len(content_list), 2):
-                        content_temp = '[' + content_list[i] + "。" + content_list[i + 1] + '。]'
-                        content_temp = content_temp.replace("。。", "。")
-                        poetrys.append(content_temp)
-                else:
-                    content = '[' + content + ']'
-                    poetrys.append(content)
-            except Exception as e:
-                pass
-    poetrys = sorted(poetrys, key=lambda line: len(line))
+    peotry_path = 'data/poetry.txt'
+    with open(peotry_path, 'r', encoding='utf-8') as f:
+        for content in f:
+            content = content.strip()
+            content = '[' + content + ']'
+            poetrys.append(content)
+    # poetrys = sorted(poetrys, key=lambda line: len(line))
     # count how often each character appears
     all_words = []
     for poetry in poetrys:
@@ -50,41 +35,79 @@ def get_data():
     words = words[:len(words)] + (' ',)
     # map each character to a numeric ID
     word2ix = dict(zip(words, range(len(words))))
-    ix2word = lambda word: word2ix.get(word, len(words))
-    data = [list(map(ix2word, poetry)) for poetry in poetrys]
-    data = numpy.array(data)
+    ix2word = {v: k for k, v in word2ix.items()}
+    data = [[word2ix[c] for c in poetry] for poetry in poetrys]
+    # data = numpy.array(data)
     return data, word2ix, ix2word
+def test(model):
+    # greedy sampling: start from '[' and emit the argmax character
+    # until the closing ']' or 100 characters
+    start_idx = [word2ix['[']]
+    end_word = ''
+    hidden = None
+    ret = ''
+    while end_word != ']' and len(ret) < 100:
+        data_ = torch.tensor([start_idx], device=device).long()
+        output, hidden = model(data_, hidden)
+        ouput_idx = output.view(-1).argmax().cpu().item()
+        start_idx = [ouput_idx]
+        end_word = ix2word[ouput_idx]
+        ret += end_word
+    return ret
 def train():
-    # load the data
-    data, word2ix, ix2word = get_data()
-    data = torch.from_numpy(data)
-    dataloader = torch.utils.data.DataLoader(data,
-                                             batch_size=batch_size,
-                                             shuffle=True,
-                                             num_workers=1)
     # define the model
-    model = RNNModel(len(word2ix), batch_size, embed_size)
+    model = RNNModel(len(word2ix), embed_size, hidden_dims)
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     criterion = nn.CrossEntropyLoss()
     model.to(device)
+    model.train()
     for epoch in range(epochs):
-        for ii, data_ in tqdm(enumerate(dataloader)):
-            data_ = data_.long().transpose(1, 0).contiguous()
-            data_ = data_.to(device)
+        total_loss = 0
+        count = 0
+        for ii, data_ in tqdm.tqdm(enumerate(data)):
+            data_ = torch.tensor(data_).long()
+            x = data_.unsqueeze(1).to(device)
             optimizer.zero_grad()
-            input_, target = data_[:-1, :], data_[1:, :]
-            output, _ = model(input_)
-            loss = criterion(output, target.view(-1))
+            # next-character targets: shift left by one, wrap the first token to the end
+            y = torch.zeros(x.shape).to(device).long()
+            y[:-1], y[-1] = x[1:], x[0]
+            output, _ = model(x)
+            loss = criterion(output, y.view(-1))
             loss.backward()
             optimizer.step()
+            total_loss += loss.item()
+            count += 1
+        print(epoch, 'loss=', total_loss / count)
         torch.save(model.state_dict(), 'model.bin')
+        chars = test(model)
+        print(chars)
 if __name__ == "__main__":
+    # load the data and persist the vocabulary for the Flask server
+    data, word2ix, ix2word = get_data()
+    with open("word2ix.pkl", 'wb') as outfile:
+        pickle.dump(word2ix, outfile)
+    with open("ix2word.pkl", 'wb') as outfile:
+        pickle.dump(ix2word, outfile)
+    data = data[:100]  # only the first 100 poems are used here
     train()
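The script above imports RNNModel, embed_size and hidden_dims from model, and the server below additionally imports generate_poetry, but model.py itself is not part of this commit. A minimal sketch of a compatible module, assuming a single-layer LSTM; the sizes and the generate_poetry implementation are reconstructions that merely match the call sites, not the author's actual code:

# hypothetical model.py; the real module is not shown in this commit
import torch
from torch import nn

embed_size = 256    # assumed: the pre-commit script hard-coded 256
hidden_dims = 256   # assumed
batch_size = 128    # assumed: the pre-commit script hard-coded 128

class RNNModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_dims):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_dims)
        self.linear = nn.Linear(hidden_dims, vocab_size)

    def forward(self, x, hidden=None):
        # x: (seq_len, batch) of character IDs
        seq_len, batch = x.size()
        emb = self.embedding(x)              # (seq_len, batch, embed_size)
        out, hidden = self.rnn(emb, hidden)  # (seq_len, batch, hidden_dims)
        logits = self.linear(out.reshape(seq_len * batch, -1))
        return logits, hidden

def generate_poetry(model, word2ix, ix2word, device, begin_word, max_len=100):
    # greedy decoding, mirroring test() in the training script: prime with
    # '[' plus the optional begin_word, then emit argmax characters until ']'
    model.eval()
    ret = begin_word
    idx = [word2ix.get(c, 0) for c in '[' + begin_word]
    hidden = None
    with torch.no_grad():
        while len(ret) < max_len:
            x = torch.tensor(idx, device=device).long().unsqueeze(1)
            output, hidden = model(x, hidden)
            next_ix = output[-1].argmax().item()
            word = ix2word[next_ix]
            if word == ']':
                break
            ret += word
            idx = [next_ix]
    return ret

The forward pass returns (seq_len * batch, vocab) logits so that criterion(output, y.view(-1)) in train() lines up element for element with the shifted targets.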


@@ -0,0 +1,31 @@
import pickle
import torch
from flask import Flask, request
from model import RNNModel, embed_size, hidden_dims, generate_poetry

app = Flask(__name__)

# vocabulary saved by the training script
with open("word2ix.pkl", 'rb') as outfile:
    word2ix = pickle.load(outfile)
with open("ix2word.pkl", 'rb') as outfile:
    ix2word = pickle.load(outfile)

# load the trained checkpoint on CPU for serving
device = torch.device('cpu')
model = RNNModel(len(word2ix), embed_size, hidden_dims)
init_checkpoint = 'model.bin'
model.load_state_dict(torch.load(init_checkpoint, map_location='cpu'))

@app.route('/')
def hello():
    return 'Hello World!'

@app.route('/peom')
def predict():
    begin_word = request.args.get('text', '')
    ret = generate_poetry(model, word2ix, ix2word, device, begin_word)
    return ret

if __name__ == '__main__':
    app.run()
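With this server running and model.bin plus the two pickle files in the working directory, the generation endpoint can be exercised as below; a sketch assuming the default Flask port, and note that the route really is spelled /peom in the code:

import requests

# '春' is just an example start character; any character in the vocabulary works
resp = requests.get('http://127.0.0.1:5000/peom', params={'text': '春'})
print(resp.text)  # the generated poem, continued from the given character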