#!/usr/bin/env python

import glob
import os.path as op
import pickle
import random

import numpy as np

# Save parameters every few SGD iterations as a fail-safe.
SAVE_PARAMS_EVERY = 5000


def load_saved_params():
    """
    A helper function that loads the most recently saved parameters and
    resets the iteration start.
    """
    st = 0
    for f in glob.glob("saved_params_*.npy"):
        # Filenames look like "saved_params_<iteration>.npy"; keep the
        # highest iteration number found.
        saved_iter = int(op.splitext(op.basename(f))[0].split("_")[2])
        if saved_iter > st:
            st = saved_iter

    if st > 0:
        params_file = "saved_params_%d.npy" % st
        state_file = "saved_state_%d.pickle" % st
        params = np.load(params_file)
        with open(state_file, "rb") as f:
            state = pickle.load(f)
        return st, params, state
    else:
        return st, None, None


def save_params(iteration, params):
    params_file = "saved_params_%d.npy" % iteration
    np.save(params_file, params)
    # Also save the RNG state so a resumed run draws the same samples.
    with open("saved_state_%d.pickle" % iteration, "wb") as f:
        pickle.dump(random.getstate(), f)


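# A minimal sketch of the checkpoint round-trip. The demo function below is
# an illustration, not part of the original assignment, and assumes no
# higher-numbered "saved_params_*.npy" files already sit in the working
# directory: save_params writes the parameters plus the RNG state, and
# load_saved_params picks the highest-numbered checkpoint back up.
def _checkpoint_roundtrip_demo():
    params = np.zeros(5)
    save_params(5000, params)
    start_iter, restored, _state = load_saved_params()
    assert start_iter == 5000
    assert np.allclose(restored, params)

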
def sgd(f, x0, step, iterations, postprocessing=None, useSaved=False,
        PRINT_EVERY=10):
    """ Stochastic Gradient Descent

    Implement the stochastic gradient descent method in this function.

    Arguments:
    f -- the function to optimize. It should take a single
         argument and yield two outputs: a loss and the gradient
         with respect to the argument.
    x0 -- the initial point to start SGD from
    step -- the step size for SGD
    iterations -- total iterations to run SGD for
    postprocessing -- postprocessing function for the parameters
         if necessary. In the case of word2vec we will need to
         normalize the word vectors to have unit length.
    PRINT_EVERY -- specifies how many iterations to output loss

    Return:
    x -- the parameter value after SGD finishes
    """

    # Anneal the learning rate every ANNEAL_EVERY iterations.
    ANNEAL_EVERY = 20000

    if useSaved:
        start_iter, oldx, state = load_saved_params()
        if start_iter > 0:
            x0 = oldx
            # Re-apply the annealing that happened before the checkpoint:
            # one halving per completed ANNEAL_EVERY period (integer
            # division, to match the stepwise halving in the loop below).
            step *= 0.5 ** (start_iter // ANNEAL_EVERY)

        if state:
            random.setstate(state)
    else:
        start_iter = 0

    x = x0

    if not postprocessing:
        postprocessing = lambda x: x

    # Exponential moving average of the loss, for smoother progress reports.
    exploss = None

    for iteration in range(start_iter + 1, iterations + 1):
        # You might want to print the progress every few iterations.

        loss = None
        ### YOUR CODE HERE
        # One SGD step: evaluate the loss and gradient at x, then move
        # against the gradient.
        loss, grad = f(x)
        x = x - step * grad
        ### END YOUR CODE

        x = postprocessing(x)
        if iteration % PRINT_EVERY == 0:
            if exploss is None:
                exploss = loss
            else:
                exploss = 0.95 * exploss + 0.05 * loss
            print("iter %d: %f" % (iteration, exploss))

        if iteration % SAVE_PARAMS_EVERY == 0 and useSaved:
            save_params(iteration, x)

        if iteration % ANNEAL_EVERY == 0:
            step *= 0.5

    return x


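# A minimal sketch (illustration only; the names below are assumptions, not
# part of the assignment) of how sgd might be called with a word2vec-style
# postprocessing that rescales each word vector to unit length after every
# update, as the docstring above describes.
def _normalized_sgd_demo():
    def normalize_rows(vecs):
        # Divide each row by its L2 norm so every vector lies on the
        # unit sphere.
        return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

    # Stand-in objective: loss is the sum of squares, gradient is 2 * vecs.
    f = lambda vecs: (np.sum(vecs ** 2), 2 * vecs)
    vecs0 = normalize_rows(np.random.randn(4, 3))
    vecs = sgd(f, vecs0, 0.01, 100, postprocessing=normalize_rows,
               PRINT_EVERY=50)
    # Every row of the result stays unit length.
    assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0)

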
def sanity_check():
    # f(x) = sum(x ** 2) with gradient 2x; the minimum is at x = 0.
    quad = lambda x: (np.sum(x ** 2), x * 2)

    print("Running sanity checks...")
    t1 = sgd(quad, 0.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 1 result:", t1)
    assert abs(t1) <= 1e-6

    t2 = sgd(quad, 0.0, 0.01, 1000, PRINT_EVERY=100)
    print("test 2 result:", t2)
    assert abs(t2) <= 1e-6

    t3 = sgd(quad, -1.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 3 result:", t3)
    assert abs(t3) <= 1e-6

    print("-" * 40)
    print("ALL TESTS PASSED")
    print("-" * 40)


if __name__ == "__main__":
    sanity_check()