64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
CS224N 2018-19: Homework 3
|
|
general_utils.py: General purpose utilities.
|
|
Sahil Chopra <schopra8@stanford.edu>
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import numpy as np
|
|
|
|
|
|
def get_minibatches(data, minibatch_size, shuffle=True):
    """
    Iterates through the provided data one minibatch at a time. You can use this function to
    iterate through data in minibatches as follows:

        for inputs_minibatch in get_minibatches(inputs, minibatch_size):
            ...

    Or with multiple data sources:

        for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
            ...

    Args:
        data: there are two possible values:
            - a list or numpy array
            - a list where each element is either a list or numpy array
        minibatch_size: the maximum number of items in a minibatch; must be a positive integer
        shuffle: whether to randomize the order of returned data
    Yields:
        minibatches: the yielded value depends on data:
            - If data is a list/array it yields the next minibatch of data.
            - If data is a list of lists/arrays it yields a list holding the next minibatch of
              each element in the list. This can be used to iterate through multiple data
              sources (e.g., features and labels) at the same time. Every source is indexed
              with the same positions (sized by the first source), so all sources are
              expected to have the same length.
            An empty list or zero-length array yields nothing.
    Raises:
        ValueError: if minibatch_size is less than 1. Because this is a generator, the
            error surfaces on the first iteration, not at call time.
    """
    if minibatch_size < 1:
        # np.arange/range with step 0 fails obscurely and a negative step silently yields
        # no batches at all; reject both explicitly.
        raise ValueError(
            "minibatch_size must be a positive integer, got {:}".format(minibatch_size))

    # Multi-source mode: a non-empty list whose first element is itself a list or ndarray.
    # The len() guard keeps an empty list from raising IndexError on data[0]. Exact type
    # checks (not isinstance) are kept to match how _minibatch dispatches on np.ndarray.
    list_data = (type(data) is list and len(data) > 0
                 and (type(data[0]) is list or type(data[0]) is np.ndarray))

    data_size = len(data[0]) if list_data else len(data)
    indices = np.arange(data_size)
    if shuffle:
        # One shared permutation keeps rows of parallel sources aligned.
        np.random.shuffle(indices)

    for minibatch_start in range(0, data_size, minibatch_size):
        minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
        if list_data:
            yield [_minibatch(d, minibatch_indices) for d in data]
        else:
            yield _minibatch(data, minibatch_indices)
|
|
|
|
|
|
def _minibatch(data, minibatch_idx):
    """
    Gather the entries of ``data`` found at the positions in ``minibatch_idx``.

    An object whose type is exactly np.ndarray is indexed in a single fancy-indexing call
    and yields an array; any other indexable sequence yields a plain list of the picked items.
    """
    if type(data) is not np.ndarray:
        return [data[position] for position in minibatch_idx]
    return data[minibatch_idx]
|
|
|
|
|
|
def test_all_close(name, actual, expected):
    """
    Check that ``actual`` matches ``expected`` element-wise to within an absolute 1e-6.

    Args:
        name: label used in the pass/fail messages.
        actual: the computed value (numpy array or array-like).
        expected: the reference value (numpy array or array-like).
    Raises:
        ValueError: if the shapes differ, or if any element differs by more than 1e-6.
            A NaN in either input counts as a mismatch.
    On success, prints "<name> passed!".
    """
    # np.shape reads .shape when present and otherwise converts, so lists/scalars work too.
    actual_shape = np.shape(actual)
    expected_shape = np.shape(expected)
    if actual_shape != expected_shape:
        raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
                         .format(name, expected_shape, actual_shape))
    # rtol=0 keeps the same purely absolute 1e-6 bound as the old max-abs-diff check, but
    # unlike `np.amax(np.fabs(diff)) > 1e-6` a NaN cannot slip through (nan > 1e-6 is False)
    # and zero-size arrays no longer crash np.amax.
    if not np.allclose(actual, expected, rtol=0, atol=1e-6, equal_nan=False):
        raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
    print(name, "passed!")


# The name starts with "test_", so pytest would try to collect this helper (and fail on its
# non-fixture parameters) whenever it is imported into a test module; opt it out.
test_all_close.__test__ = False
|