add a3
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
CS224N 2018-19: Homework 3
|
||||
general_utils.py: General purpose utilities.
|
||||
Sahil Chopra <schopra8@stanford.edu>
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_minibatches(data, minibatch_size, shuffle=True):
|
||||
"""
|
||||
Iterates through the provided data one minibatch at at time. You can use this function to
|
||||
iterate through data in minibatches as follows:
|
||||
|
||||
for inputs_minibatch in get_minibatches(inputs, minibatch_size):
|
||||
...
|
||||
|
||||
Or with multiple data sources:
|
||||
|
||||
for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
|
||||
...
|
||||
|
||||
Args:
|
||||
data: there are two possible values:
|
||||
- a list or numpy array
|
||||
- a list where each element is either a list or numpy array
|
||||
minibatch_size: the maximum number of items in a minibatch
|
||||
shuffle: whether to randomize the order of returned data
|
||||
Returns:
|
||||
minibatches: the return value depends on data:
|
||||
- If data is a list/array it yields the next minibatch of data.
|
||||
- If data a list of lists/arrays it returns the next minibatch of each element in the
|
||||
list. This can be used to iterate through multiple data sources
|
||||
(e.g., features and labels) at the same time.
|
||||
|
||||
"""
|
||||
list_data = type(data) is list and (type(data[0]) is list or type(data[0]) is np.ndarray)
|
||||
data_size = len(data[0]) if list_data else len(data)
|
||||
indices = np.arange(data_size)
|
||||
if shuffle:
|
||||
np.random.shuffle(indices)
|
||||
for minibatch_start in np.arange(0, data_size, minibatch_size):
|
||||
minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
|
||||
yield [_minibatch(d, minibatch_indices) for d in data] if list_data \
|
||||
else _minibatch(data, minibatch_indices)
|
||||
|
||||
|
||||
def _minibatch(data, minibatch_idx):
|
||||
return data[minibatch_idx] if type(data) is np.ndarray else [data[i] for i in minibatch_idx]
|
||||
|
||||
|
||||
def test_all_close(name, actual, expected):
|
||||
if actual.shape != expected.shape:
|
||||
raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
|
||||
.format(name, expected.shape, actual.shape))
|
||||
if np.amax(np.fabs(actual - expected)) > 1e-6:
|
||||
raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
|
||||
else:
|
||||
print(name, "passed!")
|
||||
Reference in New Issue
Block a user