I am struggling to implement this, and it feels like it should be more straightforward. I want to perform a full Newton update on a small neural-network problem; this is a proof of concept and does not need to scale.
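Concretely, the update I have in mind is the full Newton step theta_new = theta - H(theta)^-1 * grad L(theta), where L is the scalar training loss and H is its Hessian with respect to the flattened parameter vector.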
The problematic part of the code is here:
from jax.scipy.linalg import solve

def newton_method(params, input_data, labels, num_iters=100):
    # flatten the parameters: tree_flatten returns (leaves, treedef)
    params_flat, treedef = tree_flatten(params)
    # ravel the leaves into a single flat vector
    params_flat = jnp.concatenate([jnp.ravel(p) for p in params_flat])

    def loss(flat_params):
        # unflatten the parameters -- this is where it breaks: tree_unflatten
        # expects a list of correctly shaped leaves, not one concatenated vector
        params_unflattened = tree_unflatten(treedef, flat_params)
        return loss_0(params_unflattened, input_data, labels)

    grad_f = jax.grad(loss)
    hessian_f = jax.hessian(loss)

    for i in range(num_iters):
        # compute gradient and Hessian at the current flat parameters
        grad_params = grad_f(params_flat)
        hessian_params = hessian_f(params_flat)
        # Newton's update step
        params_flat -= solve(hessian_params, grad_params)

    new_params = tree_unflatten(treedef, params_flat)
    return new_params

newton_method(params, data, labels, num_iters=10)
This illustrates the Newton solve I am trying to implement: flatten the network parameters into a single vector, build whatever Taylor-approximation update I want from the right derivatives, and then map the result back onto the model's parameters. (At the end of the post I sketch what I suspect the flattening should look like.) Below is the rest of the code, which runs fine on its own:
import equinox as eqx
from typing import List
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import grad, jit, vmap
from jax import random
from jax.tree_util import tree_flatten, tree_unflatten
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
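# init_network_params returns a list of (weights, biases) tuples, one per layer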
layer_sizes = [2, 2, 1]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
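# with layer_sizes = [2, 2, 1], params is [(W1, b1), (W2, b2)] with shapes
# (2, 2)/(2,) and (1, 2)/(1,), i.e. 9 scalar parameters once flattened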
def predict(params, inputs):
    # per-example predictions
    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = jax.nn.tanh(outputs)
    final_w, final_b = params[-1]
    return jnp.dot(final_w, activations) + final_b
# Create a PRNG key and test a single forward pass
key = jr.PRNGKey(0)
key, xkey = jr.split(key, 2)
preds = predict(params, jnp.array([10.0, 5.0]))
print(preds.shape)
preds
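# (preds should have shape (1,), since the final layer has a single output unit)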
# Generate data points for x and y
data_points = 100
x = jnp.linspace(-10, 10, data_points)
y = jnp.linspace(-10, 10, data_points)
x, y = jnp.meshgrid(x, y)
x_flat = jnp.expand_dims(x.flatten(), 1)
y_flat = jnp.expand_dims(y.flatten(), 1)
data = jnp.concatenate([x_flat, y_flat], axis=1)
# Generate labels based on the function x^2 + y^2
labels = x_flat ** 2 + y_flat ** 2
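# data has shape (10000, 2) and labels has shape (10000, 1):
# a 100 x 100 grid of inputs with one scalar target per point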
def loss_fn(params, input_data, labels):
    predictions = predict(params, input_data)
    return (predictions - labels) ** 2

def loss_0(params, input_data, labels):
    lsn = jax.vmap(loss_fn, in_axes=(None, 0, 0))(params, input_data, labels)
    return jnp.mean(lsn)
loss_0 = jax.jit(loss_0)
loss_fn = jax.jit(loss_fn)
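# quick sanity check that the reduced loss is a scalar
print(loss_0(params, data, labels).shape)  # expect: ()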
(The script then ends with the same newton_method definition and call already shown at the top of the post.)
I have also tried working with the standard nested tree structure directly, which does produce the Hessians and gradients, but those nested outputs are hard to access and index correctly.
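From what I can tell, jax.flatten_util.ravel_pytree returns both the flat parameter vector and a function that maps a flat vector back onto the original pytree, which seems like the missing piece. Here is an untested sketch of what I think the update should look like (reusing loss_0, params, data and labels from above); is this the right way to do it?

from jax.flatten_util import ravel_pytree
from jax.scipy.linalg import solve

def newton_method_flat(params, input_data, labels, num_iters=10):
    # flatten the whole parameter pytree into one vector and keep the inverse map
    params_flat, unravel = ravel_pytree(params)

    def loss(flat):
        # rebuild the pytree before calling the original loss
        return loss_0(unravel(flat), input_data, labels)

    grad_f = jax.grad(loss)
    hessian_f = jax.hessian(loss)

    for _ in range(num_iters):
        g = grad_f(params_flat)       # shape (n_params,)
        H = hessian_f(params_flat)    # shape (n_params, n_params)
        # full Newton step on the flattened parameters
        params_flat = params_flat - solve(H, g)

    return unravel(params_flat)

new_params = newton_method_flat(params, data, labels, num_iters=10)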