add rnn nlg
This commit is contained in:
11
rnn-poetry/flask_demo.py
Normal file
11
rnn-poetry/flask_demo.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Minimal Flask application: a single route at '/' that returns a fixed
# greeting.
from flask import Flask

app = Flask(__name__)


@app.route('/')
def hello():
    """Return the greeting served at the root URL."""
    greeting = 'Hello World!'
    return greeting


if __name__ == '__main__':
    # Run Flask's built-in server with its default settings.
    app.run()
|
||||||
@@ -2,43 +2,28 @@ import re
|
|||||||
import tqdm
|
import tqdm
|
||||||
import torch
|
import torch
|
||||||
import collections
|
import collections
|
||||||
import numpy
|
import pickle
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from model import RNNModel
|
from model import RNNModel,embed_size,hidden_dims,batch_size
|
||||||
|
|
||||||
device=
|
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
batch_size=128
|
|
||||||
embed_size=256
|
|
||||||
|
|
||||||
epochs=100
|
epochs=100
|
||||||
lr=0.001
|
lr=0.001
|
||||||
def get_data():
|
def get_data():
|
||||||
poetry_file = 'data/poetry.txt'
|
|
||||||
special_character_removal = re.compile(r'[^\w。, ]', re.IGNORECASE)
|
special_character_removal = re.compile(r'[^\w。, ]', re.IGNORECASE)
|
||||||
# 诗集
|
# 诗集
|
||||||
poetrys = []
|
poetrys = []
|
||||||
with open(poetry_file, "r", encoding='utf-8', ) as f:
|
peotry_path='data/poetry.txt'
|
||||||
for line in f:
|
with open(peotry_path,'r',encoding='utf-8') as f:
|
||||||
try:
|
for content in f:
|
||||||
title, content = line.strip().split(':')
|
content=content.strip()
|
||||||
content = special_character_removal.sub('', content)
|
|
||||||
content = content.replace(' ', '')
|
|
||||||
if len(content) < 5:
|
|
||||||
continue
|
|
||||||
if (len(content) > 12 * 6):
|
|
||||||
content_list = content.split("。")
|
|
||||||
for i in range(0, len(content_list), 2):
|
|
||||||
content_temp = '[' + content_list[i] + "。" + content_list[i + 1] + '。]'
|
|
||||||
content_temp = content_temp.replace("。。", "。")
|
|
||||||
poetrys.append(content_temp)
|
|
||||||
else:
|
|
||||||
content = '[' + content + ']'
|
content = '[' + content + ']'
|
||||||
poetrys.append(content)
|
poetrys.append(content)
|
||||||
except Exception as e:
|
|
||||||
pass
|
# poetrys = sorted(poetrys, key=lambda line: len(line))
|
||||||
poetrys = sorted(poetrys, key=lambda line: len(line))
|
|
||||||
# 统计每个字出现次数
|
# 统计每个字出现次数
|
||||||
all_words = []
|
all_words = []
|
||||||
for poetry in poetrys:
|
for poetry in poetrys:
|
||||||
@@ -50,41 +35,79 @@ def get_data():
|
|||||||
words = words[:len(words)] + (' ',)
|
words = words[:len(words)] + (' ',)
|
||||||
# 每个字映射为一个数字ID
|
# 每个字映射为一个数字ID
|
||||||
word2ix = dict(zip(words, range(len(words))))
|
word2ix = dict(zip(words, range(len(words))))
|
||||||
ix2word = lambda word: word2ix.get(word, len(words))
|
ix2word = {v: k for k, v in word2ix.items()}
|
||||||
data = [list(map(ix2word, poetry)) for poetry in poetrys]
|
data = [[word2ix[c] for c in poetry ] for poetry in poetrys]
|
||||||
data=numpy.array(data)
|
# data=numpy.array(data)
|
||||||
return data,word2ix,ix2word
|
return data,word2ix,ix2word
|
||||||
|
|
||||||
|
def test(model):
    """Greedily generate one poem with ``model`` and return it as a string.

    Generation starts from the id of the '[' marker and feeds each predicted
    id back in as the next input, carrying the hidden state along.  It stops
    once ']' has been produced or 100 characters have been generated.

    Reads the module-level ``word2ix``, ``ix2word`` and ``device``.

    Args:
        model: module called as ``model(input_ids, hidden)`` that returns
            ``(output, hidden)``.  ``output`` is flattened before the argmax,
            so it is presumably the score vector for a single step
            (TODO confirm against RNNModel).

    Returns:
        The generated characters, including the closing ']' when one was
        produced; the leading '[' is not included.
    """
    start_idx = [word2ix['[']]
    end_word = ''
    hidden = None
    ret = ''

    # Sample in eval mode and restore the caller's mode afterwards, so that
    # calling this in the middle of training does not leave the model in
    # eval mode.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for sampling; without no_grad() every step
        # would extend the autograd graph through the carried hidden state.
        with torch.no_grad():
            while end_word != ']' and len(ret) < 100:
                data_ = torch.tensor([start_idx], device=device).long()
                output, hidden = model(data_, hidden)
                # Greedy decoding: take the highest-scoring id.
                output_idx = output.view(-1).argmax().item()
                start_idx = [output_idx]
                end_word = ix2word[output_idx]
                ret += end_word
    finally:
        model.train(was_training)
    return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
|
||||||
# 获取数据
|
|
||||||
data, word2ix, ix2word = get_data()
|
|
||||||
data = torch.from_numpy(data)
|
|
||||||
dataloader = torch.utils.data.DataLoader(data,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=1)
|
|
||||||
|
|
||||||
# 模型定义
|
# 模型定义
|
||||||
model = RNNModel(len(word2ix), batch_size, embed_size)
|
model = RNNModel(len(word2ix), embed_size, hidden_dims)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
model.train()
|
||||||
for epoch in range(epochs):
|
for epoch in (range(epochs)):
|
||||||
for ii, data_ in tqdm(enumerate(dataloader)):
|
total_loss=0
|
||||||
data_ = data_.long().transpose(1, 0).contiguous()
|
count=0
|
||||||
data_ = data_.to(device)
|
for ii, data_ in tqdm.tqdm(enumerate(data)):
|
||||||
|
data_=torch.tensor(data_).long()
|
||||||
|
x = data_.unsqueeze(1).to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
input_, target = data_[:-1, :], data_[1:, :]
|
y = torch.zeros(x.shape).to(device).long()
|
||||||
output, _ = model(input_)
|
y[:-1], y[-1] = x[1:], x[0]
|
||||||
|
output, _ = model(x)
|
||||||
|
loss = criterion(output, y.view(-1))
|
||||||
|
"""
|
||||||
|
hidden=None
|
||||||
|
for k in range(2,max_lenth):
|
||||||
|
data1=data_[:k]
|
||||||
|
input_, target = data1[:-1, :], data1[1:, :]
|
||||||
|
output, hidden = model(input_,hidden)
|
||||||
loss = criterion(output, target.view(-1))
|
loss = criterion(output, target.view(-1))
|
||||||
|
optimizer.step()
|
||||||
|
"""
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
total_loss+=(loss.item())
|
||||||
|
count+=1
|
||||||
|
print(epoch,'loss=',total_loss/count)
|
||||||
torch.save(model.state_dict(), 'model.bin' )
|
torch.save(model.state_dict(), 'model.bin' )
|
||||||
if __name__ == "__main__":
|
chars=test(model)
|
||||||
|
print(chars)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Build the corpus and both vocabulary mappings.  They are bound at
    # module level because train() and test() read them as globals.
    data, word2ix, ix2word = get_data()

    # Persist each mapping to its own pickle file.
    for pkl_name, mapping in (("word2ix.pkl", word2ix),
                              ("ix2word.pkl", ix2word)):
        with open(pkl_name, 'wb') as outfile:
            pickle.dump(mapping, outfile)

    # NOTE(review): only the first 100 entries are kept for training --
    # this looks like a debugging limit; confirm before a full run.
    data = data[:100]
    train()
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
from flask import Flask,request
|
||||||
|
app = Flask(__name__)
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from model import RNNModel,embed_size,hidden_dims,generate_poetry
|
||||||
|
|
||||||
|
# Load the two vocabulary mappings from their pickle files.
# NOTE: pickle.load can execute arbitrary code from the file; only load
# vocabulary files that come from a trusted source.
with open("word2ix.pkl", 'rb') as infile:
    word2ix = pickle.load(infile)
with open("ix2word.pkl", 'rb') as infile:
    ix2word = pickle.load(infile)

# Serving runs on CPU; the checkpoint is remapped to CPU on load.
device = torch.device('cpu')
model = RNNModel(len(word2ix), embed_size, hidden_dims)
init_checkpoint = 'model.bin'
model.load_state_dict(torch.load(init_checkpoint, map_location='cpu'))
# This process only runs inference, so put the model in evaluation mode
# (disables training-only behaviour such as dropout).
model.eval()
|
||||||
|
|
||||||
|
@app.route('/')
def hello():
    """Return the fixed greeting served at the root URL."""
    greeting = 'Hello World!'
    return greeting
|
||||||
|
|
||||||
|
@app.route('/poem')
@app.route('/peom')  # original (misspelled) path, kept for existing clients
def predict():
    """Return the result of ``generate_poetry`` for the ``text`` query parameter.

    The ``text`` query parameter (empty string when absent) is passed to
    ``generate_poetry`` together with the module-level model, vocabulary
    mappings and device; whatever that call returns is sent back as the
    response body.
    """
    begin_word = request.args.get('text', '')
    return generate_poetry(model, word2ix, ix2word, device, begin_word)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run()
|
||||||
Reference in New Issue
Block a user