I am trying to use gradient descent to train a model that has two outputs. My cost function therefore returns two errors. What is the typical way to deal with this problem?
I've seen mentions here and there of this problem, but I haven't come up with a satisfactory solution.
This is a toy example that reproduces my problem:
import jax.numpy as jnp
from jax import jit, random, grad
import optax

@jit
def my_model(forz, params):
    # toy model with two outputs on very different scales
    a, b = params
    a_vect = a + forz**b
    b_vect = b + forz**a
    return a_vect, b_vect * 50.

@jit
def rmse(predictions, targets):
    return jnp.sqrt(jnp.mean((predictions - targets) ** 2))

@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a, loss_b  # two scalars, not one
grad_myloss = jit(grad(my_loss, argnums=1))
# synthetic true data
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))
true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)
# Train
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)
for i in range(1000):
    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails: jax.grad needs a scalar-output function
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)
I understand that either the two errors have to be somehow aggregated into a single one, applying some kind of normalization to the losses (my output vectors have non-comparable units), for example:
@jit
def normalized_rmse(predictions, targets):
    # dividing by the target's standard deviation makes the two
    # losses dimensionless and therefore comparable
    std_dev_targets = jnp.std(targets)
    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse / std_dev_targets

@jit
def my_loss_single(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = normalized_rmse(sim_a, true_a)
    loss_b = normalized_rmse(sim_b, true_b)
    return jnp.sqrt(loss_a ** 2 + loss_b ** 2)
or that I should use the Jacobian matrix (jacrev) somehow?
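For context, here is what I believe jacrev would compute here, namely one gradient per loss term rather than the single gradient optax needs:

from jax import jacrev

# Jacobian of the tuple-valued loss w.r.t. params: a pair of
# per-loss gradient vectors, not one combined update direction
jac_myloss = jit(jacrev(my_loss, argnums=1))
jac_a, jac_b = jac_myloss(forz, model_params, true_a, true_b)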
optax, like most optimization frameworks, is only able to optimize a single-valued loss function, so you should decide what single-valued loss makes sense for your particular problem. A good option, given the RMS form of your individual losses, might be the square sum, as sketched below; with this change, your code executes without an error.
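A minimal sketch of that aggregation, reusing your my_model and rmse (the helper name my_loss_sq is just illustrative):

@jit
def my_loss_sq(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    # summing the squared RMSEs collapses the two errors into a
    # single scalar, which is what jax.grad and optax require
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a ** 2 + loss_b ** 2

grad_myloss = jit(grad(my_loss_sq, argnums=1))

The training loop itself needs no modification: grad_myloss now returns a single gradient array with the same shape as params, which optimizer.update accepts directly.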