modify a5 origin
@@ -27,11 +27,7 @@ class CharDecoder(nn.Module):
### Hint: - Use target_vocab.char2id to access the character vocabulary for the target language.
###       - Set the padding_idx argument of the embedding matrix.
###       - Create a new Embedding layer. Do not reuse embeddings created in Part 1 of this assignment.
super(CharDecoder, self).__init__()
self.charDecoder = nn.LSTM(char_embedding_size, hidden_size, batch_first=True)  # bias=True by default
self.char_output_projection = nn.Linear(hidden_size, len(target_vocab.char2id))
self.decoderCharEmb = nn.Embedding(len(target_vocab.char2id), char_embedding_size,
                                   padding_idx=target_vocab.char2id['<pad>'])
self.target_vocab = target_vocab

### END YOUR CODE
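A quick, hedged sanity check of the padding_idx choice above. This is a standalone sketch with a toy character map standing in for target_vocab.char2id; the sizes are arbitrary, not the assignment's defaults.

import torch
from torch import nn

# Toy stand-in for target_vocab.char2id.
char2id = {'<pad>': 0, '{': 1, '}': 2, 'm': 3, 'u': 4}
emb = nn.Embedding(len(char2id), 4, padding_idx=char2id['<pad>'])
# The <pad> row is initialized to zeros and receives no gradient updates,
# so padded positions contribute nothing through the embedding.
print(emb.weight[char2id['<pad>']])   # tensor([0., 0., 0., 0.], grad_fn=...)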
@@ -48,20 +44,7 @@ class CharDecoder(nn.Module):
"""
### YOUR CODE HERE for part 2b
### TODO - Implement the forward pass of the character decoder.
# input arrives as (length, batch) of character indices; the LSTM is batch_first,
# so swap to (batch, length) before embedding.
input = input.permute(1, 0).contiguous()
ip_embedding = self.decoderCharEmb(input)                         # (batch, length, char_embedding_size)
output, (h_n, c_n) = self.charDecoder(ip_embedding, dec_hidden)   # output: (batch, length, hidden_size)
s_t = self.char_output_projection(output)                         # logits: (batch, length, vocab_size)
s_t = s_t.permute(1, 0, 2).contiguous()                           # back to (length, batch, vocab_size)

return s_t, (h_n, c_n)

### END YOUR CODE
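A hedged shape walk-through of the forward pass above, rebuilt from the raw layers with toy sizes (not the assignment's defaults) so the tensor flow is easy to verify:

import torch
from torch import nn

hidden_size, char_embedding_size, vocab_size = 3, 4, 5
emb = nn.Embedding(vocab_size, char_embedding_size, padding_idx=0)
lstm = nn.LSTM(char_embedding_size, hidden_size, batch_first=True)
proj = nn.Linear(hidden_size, vocab_size)

x = torch.randint(1, vocab_size, (6, 2))          # (length=6, batch=2) character indices
out, (h_n, c_n) = lstm(emb(x.permute(1, 0)))      # out: (batch, length, hidden_size)
s_t = proj(out).permute(1, 0, 2)                  # (length, batch, vocab_size), as forward() returns
print(s_t.shape, h_n.shape)                       # torch.Size([6, 2, 5]) torch.Size([1, 2, 3])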
@@ -79,22 +62,6 @@ class CharDecoder(nn.Module):
### Hint: - Make sure padding characters do not contribute to the cross-entropy loss.
###       - char_sequence corresponds to the sequence x_1 ... x_{n+1} from the handout (e.g., <START>,m,u,s,i,c,<END>).

input = char_sequence[:-1, :]               # x_1 ... x_n        (length, batch)
output = char_sequence[1:, :]               # x_2 ... x_{n+1}    the prediction targets
target = output.reshape(-1)                 # (length * batch,)
s_t, (h_n, c_n) = self.forward(input, dec_hidden)
s_t_re = s_t.reshape(-1, s_t.shape[2])      # (length * batch, vocab_size), aligned with target

# Summed cross-entropy; <pad> targets are excluded via ignore_index.
loss = nn.CrossEntropyLoss(ignore_index=self.target_vocab.char2id['<pad>'], reduction='sum')

return loss(s_t_re, target)
### END YOUR CODE
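A small hedged demonstration of the padding behaviour relied on above: with ignore_index set to the <pad> id, padded target positions add nothing to the summed loss (toy tensors, unrelated to the assignment data):

import torch
from torch import nn

pad_id = 0
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
logits = torch.randn(4, 7)                        # 4 positions, 7-character vocabulary
target = torch.tensor([3, 5, pad_id, pad_id])     # the last two positions are padding
with_pad = loss_fn(logits, target)
without_pad = loss_fn(logits[:2], target[:2])
print(torch.allclose(with_pad, without_pad))      # True: padded positions are ignored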
def decode_greedy(self, initialStates, device, max_length=21):
@@ -114,43 +81,6 @@ class CharDecoder(nn.Module):
###       - Use torch.tensor(..., device=device) to turn a list of character indices into a tensor.
###       - We use curly brackets as start-of-word and end-of-word characters. That is, use the character '{' for <START> and '}' for <END>.
###         Their indices are self.target_vocab.start_of_word and self.target_vocab.end_of_word, respectively.
decodedWords = []
current_char = self.target_vocab.start_of_word
start_tensor = torch.tensor([current_char], device=device)           # (1,)
batch_size = initialStates[0].shape[1]
start_batch = start_tensor.repeat(batch_size, 1)                     # (batch, 1) of <START> indices
embed_current_char = self.decoderCharEmb(start_batch)                # (batch, 1, char_embedding_size)
h_n, c_n = initialStates
output_word = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
for t in range(0, max_length):
    output, (h_n, c_n) = self.charDecoder(embed_current_char, (h_n, c_n))
    s_t = self.char_output_projection(output)                        # (batch, 1, vocab_size)
    st_smax = nn.Softmax(dim=2)(s_t)
    p_next = st_smax.argmax(2)                                       # greedy pick (argmax of the raw logits would give the same result)
    current_char = p_next                                            # (batch, 1)
    embed_current_char = self.decoderCharEmb(current_char)
    output_word = torch.cat((output_word, current_char), 1)
# Convert output_word to characters; cut each word at the first '}' (end-of-word).
out_list = output_word.tolist()
out_list = [[self.target_vocab.id2char[x] for x in ilist[1:]] for ilist in out_list]
decodedWords = []
for string_list in out_list:
    stringer = ''
    for char in string_list:
        if char != '}':
            stringer = stringer + char
        else:
            break
    decodedWords.append(stringer)
return decodedWords
### END YOUR CODE
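A hedged usage sketch of decode_greedy, with the CharDecoder class from the diff above in scope. The import path and the tiny SimpleNamespace vocabulary are illustrative assumptions; in the real model the initial states come from a projection of the word-level decoder's hidden state, not zeros.

import torch
from types import SimpleNamespace
# from char_decoder import CharDecoder   # assumed module name for the class above

chars = ['<pad>', '{', '}', 'a', 'b', 'c']
toy_vocab = SimpleNamespace(char2id={c: i for i, c in enumerate(chars)},
                            id2char={i: c for i, c in enumerate(chars)},
                            start_of_word=1, end_of_word=2)
decoder = CharDecoder(hidden_size=8, char_embedding_size=4, target_vocab=toy_vocab)
h0 = torch.zeros(1, 3, 8)   # (1, batch, hidden_size)
c0 = torch.zeros(1, 3, 8)
words = decoder.decode_greedy((h0, c0), device=torch.device('cpu'), max_length=5)
print(words)                # a list of 3 strings (gibberish here, since the weights are untrained)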
rnn-poetry/run.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import re
import collections

import numpy
import torch
from torch import nn
from tqdm import tqdm   # tqdm itself is a module; import the callable so tqdm(...) below works

from model import RNNModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 128
embed_size = 256
epochs = 100
lr = 0.001
def get_data():
    poetry_file = 'data/poetry.txt'
    special_character_removal = re.compile(r'[^\w。, ]', re.IGNORECASE)
    # poem collection
    poetrys = []
    with open(poetry_file, "r", encoding='utf-8') as f:
        for line in f:
            try:
                title, content = line.strip().split(':')
                content = special_character_removal.sub('', content)
                content = content.replace(' ', '')
                if len(content) < 5:
                    continue
                if len(content) > 12 * 6:
                    # split long poems into two-sentence pieces
                    content_list = content.split("。")
                    for i in range(0, len(content_list), 2):
                        content_temp = '[' + content_list[i] + "。" + content_list[i + 1] + '。]'
                        content_temp = content_temp.replace("。。", "。")
                        poetrys.append(content_temp)
                else:
                    content = '[' + content + ']'
                    poetrys.append(content)
            except Exception:
                pass
    poetrys = sorted(poetrys, key=lambda line: len(line))
    # count how often each character occurs
    all_words = []
    for poetry in poetrys:
        all_words += [word for word in poetry]
    counter = collections.Counter(all_words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    words, _ = zip(*count_pairs)
    # keep the most common characters (here: all of them) and append a space as the padding character
    words = words[:len(words)] + (' ',)
    # map each character to a numeric ID
    word2ix = dict(zip(words, range(len(words))))
    # despite the name, this maps a character to its ID; unknown characters fall back to the padding index
    ix2word = lambda word: word2ix.get(word, len(words) - 1)
    data = [list(map(ix2word, poetry)) for poetry in poetrys]
    # pad every poem with the space character so the lists stack into a rectangular array
    max_len = max(len(p) for p in data)
    pad_id = word2ix[' ']
    data = [p + [pad_id] * (max_len - len(p)) for p in data]
    data = numpy.array(data, dtype=numpy.int64)
    return data, word2ix, ix2word
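For intuition, a hedged toy illustration of the mapping get_data builds, using two made-up lines instead of the real corpus:

import collections
poems = ['[白日依山尽。]', '[床前明月光。]']
counter = collections.Counter(''.join(poems))
words = tuple(w for w, _ in counter.most_common()) + (' ',)
word2ix = dict(zip(words, range(len(words))))
encoded = [[word2ix[w] for w in p] for p in poems]
print(encoded[0])   # the first line as a list of integer character IDs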
def train():
    # load the data
    data, word2ix, ix2word = get_data()
    data = torch.from_numpy(data)
    dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1)

    # define the model
    model = RNNModel(len(word2ix), batch_size, embed_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    for epoch in range(epochs):
        for ii, data_ in tqdm(enumerate(dataloader)):
            # (batch, seq_len) -> (seq_len, batch)
            data_ = data_.long().transpose(1, 0).contiguous()
            data_ = data_.to(device)
            optimizer.zero_grad()
            # predict every character from the one before it
            input_, target = data_[:-1, :], data_[1:, :]
            output, _ = model(input_)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()

    # save the trained weights
    torch.save(model.state_dict(), 'model.bin')


if __name__ == "__main__":
    train()
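A hedged sketch of reloading the checkpoint and sampling one line greedily. It assumes only the call pattern used in train() above (output, _ = model(input_), with output flattened to (sequence * batch, vocab) scores) and mirrors train()'s constructor call; RNNModel's actual interface lives in model.py and may require adapting this.

def sample(start='[', max_len=48):
    data, word2ix, ix2word = get_data()
    id2word = {i: w for w, i in word2ix.items()}
    model = RNNModel(len(word2ix), batch_size, embed_size)
    model.load_state_dict(torch.load('model.bin', map_location=device))
    model.to(device).eval()
    seq = [word2ix[start]]
    with torch.no_grad():
        for _ in range(max_len):
            input_ = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(1)  # (len, 1)
            output, _ = model(input_)                 # assumed shape: (len * 1, vocab) of scores
            next_id = int(output[-1].argmax())
            if id2word.get(next_id) == ']':           # ']' closes a poem in get_data()
                break
            seq.append(next_id)
    return ''.join(id2word.get(i, '') for i in seq[1:])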
rnn-poetry/run_server.py (new file, empty)