add rnn nlg
This commit is contained in:
58
rnn-poetry/model.py
Normal file
58
rnn-poetry/model.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
batch_size=128
|
||||
embed_size=128
|
||||
hidden_dims=256
|
||||
|
||||
def generate_poetry(model,word2ix,ix2word,device,begin,sent_len=4):
|
||||
start_idx=[word2ix['[']]
|
||||
end_word=''
|
||||
lens=0
|
||||
hidden = None
|
||||
ret=''
|
||||
data_ = torch.tensor([start_idx], device=device).long()
|
||||
output, hidden = model(data_, hidden)
|
||||
start_idx=[word2ix[begin]]
|
||||
ret+=begin
|
||||
while end_word!=']' and len(ret)<100:
|
||||
data_ = torch.tensor([start_idx],device=device).long()
|
||||
# print("data size",data_.size())
|
||||
output, hidden = model(data_, hidden)
|
||||
# print("output size", output.size())
|
||||
ouput_idx=output.view(-1).argmax().cpu()
|
||||
# print('ouput_idx',ouput_idx)
|
||||
# print('ouput_idx', ouput_idx.item())
|
||||
ouput_idx=ouput_idx.item()
|
||||
start_idx=[ouput_idx]
|
||||
end_word=ix2word[ouput_idx]
|
||||
ret+=end_word
|
||||
return ret
|
||||
|
||||
class RNNModel(nn.Module):
|
||||
def __init__(self, vocab_size, embedding_dim, hidden_dim):
|
||||
super(RNNModel, self).__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
|
||||
self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
|
||||
self.linear1 = nn.Linear(self.hidden_dim, vocab_size)
|
||||
|
||||
|
||||
|
||||
def forward(self, x, hidden=None):
|
||||
seq_len, batch_size = x.size()
|
||||
|
||||
|
||||
# size: (seq_len,batch_size,embeding_dim)
|
||||
embeds = self.embeddings(x)
|
||||
# output size: (seq_len,batch_size,hidden_dim)
|
||||
if hidden is None:
|
||||
output, hidden = self.lstm(embeds)
|
||||
else:
|
||||
h_0, c_0 = hidden
|
||||
output, hidden = self.lstm(embeds, (h_0, c_0))
|
||||
|
||||
# size: (seq_len*batch_size,vocab_size)
|
||||
output = self.linear1(output.view(seq_len * batch_size, -1))
|
||||
return output, hidden
|
||||
Reference in New Issue
Block a user