import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 128
embed_size = 128
hidden_dims = 256


def generate_poetry(model, word2ix, ix2word, device, begin, sent_len=4):
    """Greedily decode a poem that starts with the character `begin`.

    '[' and ']' are the start/end-of-poem markers in the vocabulary.
    `sent_len` is kept for interface compatibility but is not used here.
    """
    start_idx = [word2ix['[']]
    end_word = ''
    hidden = None
    ret = ''

    # Feed the start marker once to initialise the hidden state.
    data_ = torch.tensor([start_idx], device=device).long()
    output, hidden = model(data_, hidden)

    # Continue decoding from the user-supplied first character.
    start_idx = [word2ix[begin]]
    ret += begin

    # Greedy decoding: always take the most probable next character until the
    # end marker ']' is produced or the poem reaches 100 characters.
    while end_word != ']' and len(ret) < 100:
        data_ = torch.tensor([start_idx], device=device).long()  # shape (1, 1)
        output, hidden = model(data_, hidden)
        output_idx = output.view(-1).argmax().item()
        start_idx = [output_idx]
        end_word = ix2word[output_idx]
        ret += end_word
    return ret
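

# --- Illustrative training step (a sketch, not part of the original file) ---
# Shows how the flattened (seq_len * batch_size, vocab_size) output of the model
# defined below is typically paired with a target sequence shifted by one
# character. `optimizer` (e.g. torch.optim.Adam(model.parameters())) and
# `criterion` (e.g. nn.CrossEntropyLoss()) are assumed to be built by the caller.
def train_step(model, batch, optimizer, criterion, device):
    # batch: LongTensor of token indices, shape (seq_len, batch_size)
    batch = batch.to(device).long()
    optimizer.zero_grad()
    inputs, targets = batch[:-1, :], batch[1:, :]  # predict the next character
    output, _ = model(inputs)                      # ((seq_len-1)*batch_size, vocab_size)
    loss = criterion(output, targets.reshape(-1))  # flatten targets to match the logits
    loss.backward()
    optimizer.step()
    return loss.item()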


class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNNModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x holds token indices with shape (seq_len, batch_size)
        seq_len, batch_size = x.size()

        # embeds: (seq_len, batch_size, embedding_dim)
        embeds = self.embeddings(x)

        # output: (seq_len, batch_size, hidden_dim)
        if hidden is None:
            output, hidden = self.lstm(embeds)
        else:
            h_0, c_0 = hidden
            output, hidden = self.lstm(embeds, (h_0, c_0))

        # Flatten to (seq_len * batch_size, vocab_size) logits.
        output = self.linear1(output.view(seq_len * batch_size, -1))
        return output, hidden
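

# --- Usage sketch (illustrative; names such as toy_word2ix and tang.pth are
# assumptions, not from the original file) ---
# Assumes a vocabulary in which '[' marks the start of a poem and ']' the end.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Toy vocabulary just to show the plumbing; a real one comes from the corpus.
    chars = ['[', ']', '春', '眠', '不', '觉', '晓']
    toy_word2ix = {c: i for i, c in enumerate(chars)}
    toy_ix2word = {i: c for i, c in enumerate(chars)}

    model = RNNModel(len(chars), embed_size, hidden_dims).to(device)
    # An untrained model produces gibberish; in practice load trained weights,
    # e.g. model.load_state_dict(torch.load('tang.pth', map_location=device))
    model.eval()
    poem = generate_poetry(model, toy_word2ix, toy_ix2word, device, begin='春')
    print(poem)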