modify test
This commit is contained in:
@@ -6,6 +6,23 @@ CS224N 2018-19: Homework 5
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
### YOUR CODE HERE for part 1i
|
### YOUR CODE HERE for part 1i
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
class CNN(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch,k=5):
|
||||||
|
"""
|
||||||
|
Apply the output of the convolution later (x_conv) through a highway network
|
||||||
|
@param D_in (int): Size of input layer
|
||||||
|
@param H (int): Size of Hidden layer
|
||||||
|
@param D_out (int): Size of output layer
|
||||||
|
@param prob (float): Probability of dropout
|
||||||
|
"""
|
||||||
|
super(CNN, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
||||||
### END YOUR CODE
|
### END YOUR CODE
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,28 @@ CS224N 2018-19: Homework 5
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
### YOUR CODE HERE for part 1h
|
### YOUR CODE HERE for part 1h
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
class Highway(torch.nn.Module):
|
||||||
|
def __init__(self, D_in, H, D_out,prob):
|
||||||
|
"""
|
||||||
|
Apply the output of the convolution later (x_conv) through a highway network
|
||||||
|
@param D_in (int): Size of input layer
|
||||||
|
@param H (int): Size of Hidden layer
|
||||||
|
@param D_out (int): Size of output layer
|
||||||
|
@param prob (float): Probability of dropout
|
||||||
|
"""
|
||||||
|
super(Highway, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Apply the output of the convolution later (x_conv) through a highway network
|
||||||
|
@param x (Tensor): Input x_cov gets applied to Highway network - shape of input tensor [batch_size,1,e_word]
|
||||||
|
@returns x_pred (Tensor): Size of Hidden layer -- NOTE: check the shapes
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
### END YOUR CODE
|
### END YOUR CODE
|
||||||
|
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ def main():
|
|||||||
|
|
||||||
# Check Python & PyTorch Versions
|
# Check Python & PyTorch Versions
|
||||||
assert (sys.version_info >= (3, 5)), "Please update your installation of Python to version >= 3.5"
|
assert (sys.version_info >= (3, 5)), "Please update your installation of Python to version >= 3.5"
|
||||||
assert(torch.__version__ == "1.0.0"), "Please update your installation of PyTorch. You have {} and you should have version 1.0.0".format(torch.__version__)
|
assert(torch.__version__ >= "1.0.0"), "Please update your installation of PyTorch. You have {} and you should have version 1.0.0".format(torch.__version__)
|
||||||
|
|
||||||
# Seed the Random Number Generators
|
# Seed the Random Number Generators
|
||||||
seed = 1234
|
seed = 1234
|
||||||
|
|||||||
Reference in New Issue
Block a user