From 704f59eaed9e9dac4c8bc13e13060e1c89e180e2 Mon Sep 17 00:00:00 2001 From: "chongjiu.jin" Date: Fri, 13 Dec 2019 14:39:11 +0800 Subject: [PATCH] add rnn nlg --- rnn-poetry/flask_demo.py | 11 ++++ rnn-poetry/run.py | 119 +++++++++++++++++++++++---------------- rnn-poetry/run_server.py | 31 ++++++++++ 3 files changed, 113 insertions(+), 48 deletions(-) create mode 100644 rnn-poetry/flask_demo.py diff --git a/rnn-poetry/flask_demo.py b/rnn-poetry/flask_demo.py new file mode 100644 index 0000000..9005441 --- /dev/null +++ b/rnn-poetry/flask_demo.py @@ -0,0 +1,11 @@ +from flask import Flask +app = Flask(__name__) + + +@app.route('/') +def hello(): + return 'Hello World!' + + +if __name__ == '__main__': + app.run() \ No newline at end of file diff --git a/rnn-poetry/run.py b/rnn-poetry/run.py index a4aaf5b..4dc7319 100644 --- a/rnn-poetry/run.py +++ b/rnn-poetry/run.py @@ -2,43 +2,28 @@ import re import tqdm import torch import collections -import numpy +import pickle from torch import nn -from model import RNNModel +from model import RNNModel,embed_size,hidden_dims,batch_size -device= +device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') -batch_size=128 -embed_size=256 epochs=100 lr=0.001 def get_data(): - poetry_file = 'data/poetry.txt' special_character_removal = re.compile(r'[^\w。, ]', re.IGNORECASE) # 诗集 poetrys = [] - with open(poetry_file, "r", encoding='utf-8', ) as f: - for line in f: - try: - title, content = line.strip().split(':') - content = special_character_removal.sub('', content) - content = content.replace(' ', '') - if len(content) < 5: - continue - if (len(content) > 12 * 6): - content_list = content.split("。") - for i in range(0, len(content_list), 2): - content_temp = '[' + content_list[i] + "。" + content_list[i + 1] + '。]' - content_temp = content_temp.replace("。。", "。") - poetrys.append(content_temp) - else: - content = '[' + content + ']' - poetrys.append(content) - except Exception as e: - pass - poetrys = sorted(poetrys, key=lambda line: len(line)) + peotry_path='data/poetry.txt' + with open(peotry_path,'r',encoding='utf-8') as f: + for content in f: + content=content.strip() + content = '[' + content + ']' + poetrys.append(content) + + # poetrys = sorted(poetrys, key=lambda line: len(line)) # 统计每个字出现次数 all_words = [] for poetry in poetrys: @@ -50,41 +35,79 @@ def get_data(): words = words[:len(words)] + (' ',) # 每个字映射为一个数字ID word2ix = dict(zip(words, range(len(words)))) - ix2word = lambda word: word2ix.get(word, len(words)) - data = [list(map(ix2word, poetry)) for poetry in poetrys] - data=numpy.array(data) + ix2word = {v: k for k, v in word2ix.items()} + data = [[word2ix[c] for c in poetry ] for poetry in poetrys] + # data=numpy.array(data) return data,word2ix,ix2word +def test(model): + start_idx=[word2ix['[']] + end_word='' + lens=0 + hidden = None + ret='' + while end_word!=']' and len(ret)<100: + data_ = torch.tensor([start_idx],device=device).long() + # print("data size",data_.size()) + output, hidden = model(data_, hidden) + # print("output size", output.size()) + ouput_idx=output.view(-1).argmax().cpu() + # print('ouput_idx',ouput_idx) + # print('ouput_idx', ouput_idx.item()) + ouput_idx=ouput_idx.item() + start_idx=[ouput_idx] + end_word=ix2word[ouput_idx] + ret+=end_word + return ret + + + + def train(): - # 获取数据 - data, word2ix, ix2word = get_data() - data = torch.from_numpy(data) - dataloader = torch.utils.data.DataLoader(data, - batch_size=batch_size, - shuffle=True, - num_workers=1) - # 模型定义 - model = RNNModel(len(word2ix), batch_size, embed_size) + model = RNNModel(len(word2ix), embed_size, hidden_dims) optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() model.to(device) - - for epoch in range(epochs): - for ii, data_ in tqdm(enumerate(dataloader)): - data_ = data_.long().transpose(1, 0).contiguous() - data_ = data_.to(device) + model.train() + for epoch in (range(epochs)): + total_loss=0 + count=0 + for ii, data_ in tqdm.tqdm(enumerate(data)): + data_=torch.tensor(data_).long() + x = data_.unsqueeze(1).to(device) optimizer.zero_grad() - input_, target = data_[:-1, :], data_[1:, :] - output, _ = model(input_) - loss = criterion(output, target.view(-1)) + y = torch.zeros(x.shape).to(device).long() + y[:-1], y[-1] = x[1:], x[0] + output, _ = model(x) + loss = criterion(output, y.view(-1)) + """ + hidden=None + for k in range(2,max_lenth): + data1=data_[:k] + input_, target = data1[:-1, :], data1[1:, :] + output, hidden = model(input_,hidden) + loss = criterion(output, target.view(-1)) + optimizer.step() + """ loss.backward() optimizer.step() + total_loss+=(loss.item()) + count+=1 + print(epoch,'loss=',total_loss/count) + torch.save(model.state_dict(), 'model.bin' ) + chars=test(model) + print(chars) - - torch.save(model.state_dict(), 'model.bin' ) if __name__ == "__main__": + # 获取数据 + data, word2ix, ix2word = get_data() + with open("word2ix.pkl", 'wb') as outfile: + pickle.dump(word2ix,outfile) + with open("ix2word.pkl", 'wb') as outfile: + pickle.dump(ix2word,outfile) + data=data[:100] train() \ No newline at end of file diff --git a/rnn-poetry/run_server.py b/rnn-poetry/run_server.py index e69de29..70be2dd 100644 --- a/rnn-poetry/run_server.py +++ b/rnn-poetry/run_server.py @@ -0,0 +1,31 @@ +import pickle + +from flask import Flask,request +app = Flask(__name__) +import torch + +from model import RNNModel,embed_size,hidden_dims,generate_poetry + +with open("word2ix.pkl", 'rb') as outfile: + word2ix=pickle.load(outfile) +with open("ix2word.pkl", 'rb') as outfile: + ix2word=pickle.load(outfile) + +device=torch.device('cpu') +model = RNNModel(len(word2ix), embed_size, hidden_dims) +init_checkpoint = 'model.bin' +model.load_state_dict(torch.load(init_checkpoint, map_location='cpu')) + +@app.route('/') +def hello(): + return 'Hello World!' + +@app.route('/peom') +def predict(): + begin_word = request.args.get('text', '') + ret=generate_poetry(model,word2ix,ix2word,device,begin_word) + + return ret + +if __name__ == '__main__': + app.run() \ No newline at end of file