Files
cs224n_2019/[finished]Assignment_3_neural_dependency_parsing/utils/general_utils.py
chongjiu.jin 86eeba3b38 add a3
2019-11-25 10:52:19 +08:00

64 lines
2.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 3
general_utils.py: General purpose utilities.
Sahil Chopra <schopra8@stanford.edu>
"""
import sys
import time
import numpy as np
def get_minibatches(data, minibatch_size, shuffle=True):
    """
    Iterate through the provided data one minibatch at a time.

    You can use this function to iterate through data in minibatches as follows:

        for inputs_minibatch in get_minibatches(inputs, minibatch_size):
            ...

    Or with multiple data sources:

        for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
            ...

    Args:
        data: there are two possible values:
            - a list or numpy array
            - a list where each element is either a list or numpy array
              (decided by inspecting the first element only)
        minibatch_size: the maximum number of items in a minibatch; must be >= 1
        shuffle: whether to randomize the order of returned data
                 (uses numpy's global random state)

    Yields:
        minibatches: the yielded value depends on data:
            - If data is a list/array it yields the next minibatch of data.
            - If data is a list of lists/arrays it yields a list holding the next
              minibatch of each element. This can be used to iterate through multiple
              data sources (e.g., features and labels) at the same time; the same
              indices are used for every source so they stay aligned.

    Raises:
        ValueError: if minibatch_size is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts.
    """
    if minibatch_size < 1:
        # A non-positive step would otherwise fail obscurely or silently yield no batches.
        raise ValueError("minibatch_size must be at least 1, got {:}".format(minibatch_size))

    # Only look at data[0] when it exists: an empty list is a single, empty data source.
    list_data = (isinstance(data, list) and len(data) > 0
                 and isinstance(data[0], (list, np.ndarray)))
    data_size = len(data[0]) if list_data else len(data)

    indices = np.arange(data_size)
    if shuffle:
        np.random.shuffle(indices)

    for minibatch_start in range(0, data_size, minibatch_size):
        # Slicing past the end is safe, so the last minibatch may simply be shorter.
        minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
        if list_data:
            yield [_minibatch(d, minibatch_indices) for d in data]
        else:
            yield _minibatch(data, minibatch_indices)


def _minibatch(data, minibatch_idx):
    """
    Select the items of `data` at the positions listed in `minibatch_idx`.

    Numpy arrays are indexed in one step with the index array (result is an array);
    any other indexable sequence is indexed element by element (result is a list).
    """
    if isinstance(data, np.ndarray):
        return data[minibatch_idx]
    return [data[i] for i in minibatch_idx]
def test_all_close(name, actual, expected):
    """
    Check that `actual` matches `expected` element-wise within an absolute tolerance of 1e-6.

    Args:
        name: label used in the success / failure messages
        actual: computed value; must expose `.shape` and support subtraction with
                `expected` (e.g. a numpy array)
        expected: reference value with the same requirements as `actual`

    Raises:
        ValueError: if the shapes differ, or if any element differs by more than 1e-6.
            NaN differences count as failures.

    On success prints "<name> passed!".
    """
    if actual.shape != expected.shape:
        raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
                         .format(name, expected.shape, actual.shape))
    # Written as "all within tolerance" rather than "max exceeds tolerance" so that a NaN
    # difference (for which every comparison is False) fails, and zero-size arrays pass
    # instead of tripping the empty-reduction error of np.amax.
    if not np.all(np.fabs(actual - expected) <= 1e-6):
        raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
    print(name, "passed!")


# The name matches pytest's default test-function pattern although this is a plain helper;
# opt out of collection so importing it into a test module does not produce a fixture error.
test_all_close.__test__ = False