How to train a model using gradient descent with a multi-output (vector-valued) loss function in JAX?


I am trying to train a model that has two outputs with gradient descent. My cost function therefore returns two errors. What is the typical way to deal with this problem?

I've seen mentions here and there of this problem, but I haven't come up with a satisfactory solution.

This is a toy example that reproduces my problem:

from jax import jit, random, grad
import jax.numpy as jnp
import optax


@jit
def my_model(forz, params):
    a, b = params

    a_vect = a + forz**b
    b_vect = b + forz**a

    return a_vect, b_vect*50.


@jit
def rmse(predictions, targets):

    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse


@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a, loss_b


grad_myloss = jit(grad(my_loss, argnums=1))

# synthetic true data
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))

true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)

# Train
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)

for i in range(1000):

    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails: grad needs a scalar-valued function, but my_loss returns a tuple
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)

I understand that either the two errors have to be aggregated into a single scalar, for example by normalizing each loss first (my output vectors have non-comparable units),

@jit
def normalized_rmse(predictions, targets):
    # Scale the RMSE by the standard deviation of the targets so the two
    # losses become comparable despite their different units.
    std_dev_targets = jnp.std(targets)
    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse / std_dev_targets


@jit
def my_loss_single(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = normalized_rmse(sim_a, true_a)
    loss_b = normalized_rmse(sim_b, true_b)

    return jnp.sqrt(loss_a ** 2 + loss_b ** 2)

or that I should somehow use the Jacobian matrix (jacrev)?
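
As far as I can tell, jacrev would give me one gradient per loss rather than a single one, so optax could not consume it directly. This is a rough sketch of what I mean, reusing the definitions above (the names jac_myloss, jac_a and jac_b are just illustrative):

from jax import jacrev

# Jacobian of the tuple-valued loss with respect to params (argnums=1).
# Because my_loss returns (loss_a, loss_b), this yields a tuple of two
# gradient arrays, one per loss, each with the same shape as params.
jac_myloss = jit(jacrev(my_loss, argnums=1))

jac_a, jac_b = jac_myloss(forz, model_params, true_a, true_b)

# optax expects a single gradient with the same structure as params,
# so these two gradients would still have to be combined by hand.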


1 Answer

Accepted answer by jakevdp:

optax, like most optimization frameworks, can only optimize a single-valued (scalar) loss function. You should decide what single-valued loss makes sense for your particular problem. Given the RMS form of your individual losses, a good option might be the sum of squares:

@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a ** 2 + loss_b ** 2

With this change, your code executes without an error.
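
For completeness, a minimal sketch of the training loop from the question rerun against this scalar loss; nothing else changes, and grad now returns a single gradient array:

grad_myloss = jit(grad(my_loss, argnums=1))

model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)

for i in range(1000):
    # my_loss is now scalar-valued, so grad returns one array shaped like params
    grads = grad_myloss(forz, model_params, true_a, true_b)
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)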