From 2b6bdc4718ff50d696bc0b13907f6bc7477281e3 Mon Sep 17 00:00:00 2001
From: "chongjiu.jin"
Date: Tue, 3 Dec 2019 16:41:43 +0800
Subject: [PATCH] modify a5 origin

---
 .../Assignment 5/char_decoder.py | 76 +----------------
 rnn-poetry/run.py                | 90 +++++++++++++++++++
 rnn-poetry/run_server.py         |  0
 3 files changed, 93 insertions(+), 73 deletions(-)
 create mode 100644 rnn-poetry/run.py
 create mode 100644 rnn-poetry/run_server.py

diff --git a/Assignment_origin/Assignment 5/char_decoder.py b/Assignment_origin/Assignment 5/char_decoder.py
index 482cb32..8eaf45c 100644
--- a/Assignment_origin/Assignment 5/char_decoder.py
+++ b/Assignment_origin/Assignment 5/char_decoder.py
@@ -27,11 +27,7 @@ class CharDecoder(nn.Module):
         ### Hint: - Use target_vocab.char2id to access the character vocabulary for the target language.
         ###       - Set the padding_idx argument of the embedding matrix.
         ###       - Create a new Embedding layer. Do not reuse embeddings created in Part 1 of this assignment.
-        super(CharDecoder, self).__init__()
-        self.charDecoder = nn.LSTM(char_embedding_size,hidden_size,batch_first=True) #bias = True
-        self.char_output_projection = nn.Linear(hidden_size,len(target_vocab.char2id))
-        self.decoderCharEmb = nn.Embedding(len(target_vocab.char2id),char_embedding_size,padding_idx=target_vocab.char2id['<pad>'])
-        self.target_vocab = target_vocab
+
 
         ### END YOUR CODE
 
@@ -48,20 +44,7 @@ class CharDecoder(nn.Module):
         """
         ### YOUR CODE HERE for part 2b
         ### TODO - Implement the forward pass of the character decoder.
-        #print('size of input is',input.size())
-        input = input.permute(1,0).contiguous()
-        ip_embedding=self.decoderCharEmb(input)# F.embedding(source_padded, self.model_embeddings.source.weight)
-        #X = nn.utils.rnn.pack_padded_sequence(src_padded_embedding,source_lengths)
-
-        #ip_embedding = ip_embedding.permute(1,0,2).contiguous()
-
-        output,(h_n,c_n) = self.charDecoder(ip_embedding,dec_hidden)
-        #print('shape of hidden is',h_n.size())
-        s_t = self.char_output_projection(output)
-        #print('shape of logits is',s_t.size())
-        s_t = s_t.permute(1,0,2).contiguous()
-
-        return s_t,(h_n,c_n)
+
 
         ### END YOUR CODE
 
@@ -79,22 +62,6 @@ class CharDecoder(nn.Module):
         ### Hint: - Make sure padding characters do not contribute to the cross-entropy loss.
         ###       - char_sequence corresponds to the sequence x_1 ... x_{n+1} from the handout (e.g., <START>,m,u,s,i,c,<END>).
-        input = char_sequence[:-1,:]
-        output = char_sequence[1:,:]
-        #print(input)
-        #print(output)
-        target = output.reshape(-1)
-        #print('shape of target',target.shape)
-        s_t,(h_n,c_n) = self.forward(input,dec_hidden)
-        #print('shape of s_t',s_t.shape)
-        s_t_shape = s_t.shape
-        s_t_re = s_t.reshape(-1,s_t.shape[2])
-
-
-        #print('shape of s_t_re',s_t_re.shape)
-        loss = nn.CrossEntropyLoss(ignore_index=self.target_vocab.char2id['<pad>'],reduction='sum')
-
-        return loss(s_t_re,target)
         ### END YOUR CODE
 
     def decode_greedy(self, initialStates, device, max_length=21):
 
@@ -114,43 +81,6 @@ class CharDecoder(nn.Module):
         ###      - Use torch.tensor(..., device=device) to turn a list of character indices into a tensor.
         ###      - We use curly brackets as start-of-word and end-of-word characters. That is, use the character '{' for <START> and '}' for <END>.
         ###        Their indices are self.target_vocab.start_of_word and self.target_vocab.end_of_word, respectively.
-        decodedWords = []
-        current_char = self.target_vocab.start_of_word
-        start_tensor = torch.tensor([current_char],device=device)
-        #print('size of start_tensor is',start_tensor.shape)
-        batch_size = initialStates[0].shape[1]
-        start_batch = start_tensor.repeat(batch_size,1)
-        #print('size of start_batch is',start_batch.shape)
-        embed_current_char = self.decoderCharEmb(start_batch)
-        #print('size of embed_current_char is',embed_current_char.shape)
-        h_n,c_n = initialStates
-        output_word = torch.zeros((batch_size,1),dtype=torch.long,device=device)
-        for t in range(0,max_length):
-            #h_n,c_n = self.charDecoder(embed_current_char,(h_n,c_n))
-            # s_t,(h_n,c_n) = self.forward(embed_current_char,(h_n,c_n))
-            #print('shape of embed_current_char is',embed_current_char.shape)
-            output,(h_n,c_n) = self.charDecoder(embed_current_char,(h_n,c_n))
-            s_t = self.char_output_projection(output)
-            #print(s_t.shape)
-            st_smax = nn.Softmax(dim=2)(s_t)
-            p_next = st_smax.argmax(2)
-            current_char = p_next
-            embed_current_char = self.decoderCharEmb(current_char)
-            #decodedWords.append(self.target_vocab.id2char[current_char])
-            #print('*** size of current_char is',current_char.size())
-            output_word = torch.cat((output_word,current_char),1)
-        #Convert output_word tensor to list and each element to char and put together in decodedWords
-        out_list = output_word.tolist()
-        out_list = [[self.target_vocab.id2char[x] for x in ilist[1:]] for ilist in out_list]
-        decodedWords = []
-        for string_list in out_list:
-            stringer = ''
-            for char in string_list:
-                if char!='}':
-                    stringer = stringer+char
-                else:
-                    break
-            decodedWords.append(stringer)
-        return decodedWords
+
 
         ### END YOUR CODE
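For reference, the solution stripped out above wires the decoder as embedding -> unidirectional LSTM -> linear projection over the character vocabulary. The following standalone shape sketch traces a tensor through those three layers; all sizes are made up, and it uses PyTorch's default length-first layout, whereas the deleted code chose batch_first=True and permuted instead:

    import torch
    from torch import nn

    # Made-up sizes, for illustration only.
    length, batch, char_embed_size, hidden_size, vocab_size = 21, 5, 50, 256, 96

    embed = nn.Embedding(vocab_size, char_embed_size, padding_idx=0)
    decoder = nn.LSTM(char_embed_size, hidden_size)   # length-first by default
    projection = nn.Linear(hidden_size, vocab_size)

    x = torch.randint(0, vocab_size, (length, batch))   # character indices
    output, (h_n, c_n) = decoder(embed(x))              # (length, batch, hidden_size)
    scores = projection(output)                         # (length, batch, vocab_size)
    assert scores.shape == (length, batch, vocab_size)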
diff --git a/rnn-poetry/run.py b/rnn-poetry/run.py
new file mode 100644
index 0000000..a4aaf5b
--- /dev/null
+++ b/rnn-poetry/run.py
@@ -0,0 +1,90 @@
+import re
+import collections
+
+import numpy
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from model import RNNModel
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+batch_size = 128
+embed_size = 256
+epochs = 100
+lr = 0.001
+
+
+def get_data():
+    poetry_file = 'data/poetry.txt'
+    special_character_removal = re.compile(r'[^\w。, ]', re.IGNORECASE)
+    # Poetry collection.
+    poetrys = []
+    with open(poetry_file, "r", encoding='utf-8') as f:
+        for line in f:
+            try:
+                title, content = line.strip().split(':')
+                content = special_character_removal.sub('', content)
+                content = content.replace(' ', '')
+                if len(content) < 5:
+                    continue
+                if len(content) > 12 * 6:
+                    # Split long poems into two-sentence chunks.
+                    content_list = content.split("。")
+                    for i in range(0, len(content_list), 2):
+                        content_temp = '[' + content_list[i] + "。" + content_list[i + 1] + '。]'
+                        content_temp = content_temp.replace("。。", "。")
+                        poetrys.append(content_temp)
+                else:
+                    content = '[' + content + ']'
+                    poetrys.append(content)
+            except Exception:
+                pass
+    poetrys = sorted(poetrys, key=lambda line: len(line))
+    # Count how often each character appears.
+    all_words = []
+    for poetry in poetrys:
+        all_words += [word for word in poetry]
+    counter = collections.Counter(all_words)
+    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
+    words, _ = zip(*count_pairs)
+    # Keep the most common characters (here: all of them) and append ' ' as the padding/unknown token.
+    words = words[:len(words)] + (' ',)
+    # Map each character to a numeric ID.
+    word2ix = dict(zip(words, range(len(words))))
+    # NOTE: despite the name, this maps word -> ID (unknown words fall back to len(words)).
+    ix2word = lambda word: word2ix.get(word, len(words))
+    data = [list(map(ix2word, poetry)) for poetry in poetrys]
+    # Pad poems to a common length so they stack into a rectangular array.
+    pad_id = word2ix[' ']
+    max_len = max(len(p) for p in data)
+    data = [p + [pad_id] * (max_len - len(p)) for p in data]
+    data = numpy.array(data)
+    return data, word2ix, ix2word
+
+
+def train():
+    # Load the data.
+    data, word2ix, ix2word = get_data()
+    data = torch.from_numpy(data)
+    dataloader = torch.utils.data.DataLoader(data,
+                                             batch_size=batch_size,
+                                             shuffle=True,
+                                             num_workers=1)
+
+    # Define the model.
+    model = RNNModel(len(word2ix), batch_size, embed_size)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+    model.to(device)
+
+    for epoch in range(epochs):
+        for ii, data_ in tqdm(enumerate(dataloader)):
+            data_ = data_.long().transpose(1, 0).contiguous()
+            data_ = data_.to(device)
+            optimizer.zero_grad()
+            # Next-character language modeling: target is the input shifted by one.
+            input_, target = data_[:-1, :], data_[1:, :]
+            output, _ = model(input_)
+            loss = criterion(output, target.view(-1))
+            loss.backward()
+            optimizer.step()
+
+    torch.save(model.state_dict(), 'model.bin')
+
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
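Both train() above and the char_decoder solution removed earlier optimize the same next-token objective: the input is the sequence without its last token, the target is the sequence shifted left by one, and logits and targets are flattened before cross-entropy. A minimal standalone illustration, with made-up sizes:

    import torch
    from torch import nn

    seq_len, batch, vocab = 8, 4, 100
    tokens = torch.randint(0, vocab, (seq_len, batch))   # a batch of sequences
    logits = torch.randn(seq_len - 1, batch, vocab)      # model scores for positions 2..n+1

    target = tokens[1:, :].reshape(-1)                   # shift by one, then flatten
    loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab), target)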
diff --git a/rnn-poetry/run_server.py b/rnn-poetry/run_server.py
new file mode 100644
index 0000000..e69de29
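run_server.py is created empty. One plausible starting point for it is a generation routine that loads the model.bin checkpoint written by train() and greedily samples characters until the ']' end-of-poem marker. This is only a sketch, since model.py is not part of this patch: RNNModel's constructor arguments mirror the call in run.py, and the optional hidden-state argument and the (output, hidden) return value are assumptions.

    import torch

    from model import RNNModel
    from run import get_data, batch_size, embed_size, device

    def generate(start='[', max_len=125):
        # Rebuild the vocabulary exactly the way training did.
        data, word2ix, _ = get_data()
        id2word = {i: w for w, i in word2ix.items()}
        model = RNNModel(len(word2ix), batch_size, embed_size)  # args as in run.py
        model.load_state_dict(torch.load('model.bin', map_location=device))
        model.to(device).eval()

        ids = [word2ix[start]]
        hidden = None                                 # assumed optional hidden-state arg
        with torch.no_grad():
            for _ in range(max_len):
                inp = torch.tensor([ids[-1]], device=device).view(1, 1)
                output, hidden = model(inp, hidden)   # assumed signature
                next_id = int(output[-1].argmax())
                if id2word.get(next_id) == ']':       # ']' closes a poem
                    break
                ids.append(next_id)
        return ''.join(id2word[i] for i in ids[1:])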