#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ CS224N 2018-19: Homework 4 nmt.py: NMT Model Pencheng Yin Sahil Chopra """ import math from typing import List import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def pad_sents(sents, pad_token): """ Pad list of sentences according to the longest sentence in the batch. @param sents (list[list[str]]): list of sentences, where each sentence is represented as a list of words @param pad_token (str): padding token @returns sents_padded (list[list[str]]): list of sentences where sentences shorter than the max length sentence are padded out with the pad_token, such that each sentences in the batch now has equal length. """ sents_padded = [] ### YOUR CODE HERE (~6 Lines) lengths = [ len(sent) for sent in sents] max_length = max(lengths) for i in range(len(lengths)): num_append = max_length - lengths[i] for j in range(num_append): sents[i].append(pad_token) sents_padded = sents ### END YOUR CODE return sents_padded def read_corpus(file_path, source): """ Read file, where each sentence is dilineated by a `\n`. @param file_path (str): path to file containing corpus @param source (str): "tgt" or "src" indicating whether text is of the source language or target language """ data = [] for line in open(file_path): sent = line.strip().split(' ') # only append and to the target sentence if source == 'tgt': sent = [''] + sent + [''] data.append(sent) return data def batch_iter(data, batch_size, shuffle=False): """ Yield batches of source and target sentences reverse sorted by length (largest to smallest). @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence @param batch_size (int): batch size @param shuffle (boolean): whether to randomly shuffle the dataset """ batch_num = math.ceil(len(data) / batch_size) index_array = list(range(len(data))) if shuffle: np.random.shuffle(index_array) for i in range(batch_num): indices = index_array[i * batch_size: (i + 1) * batch_size] examples = [data[idx] for idx in indices] examples = sorted(examples, key=lambda e: len(e[0]), reverse=True) src_sents = [e[0] for e in examples] tgt_sents = [e[1] for e in examples] yield src_sents, tgt_sents