I'm mostly an outsider trying to understand if Rust is appropriate for my projects.

There are frameworks that do automatic differentiation in Rust. Specifically, candle (and, I think, some other projects) somehow does it in a way that's similar to PyTorch, according to its description.

However, I know that Rust does not allow multiple mutable references, and it seems like that is exactly what PyTorch-like automatic differentiation needs:

x = torch.rand(10) # an array of 10 elements
x.requires_grad = True

y = x.sin()
z = x**2

Both y and z must keep mutable references to x, because you might want to backpropagate through them, which will modify x.grad. For example:

(y.dot(z)).backward()
print(x.grad) # .backward() adds a new field (an array) to x, without modifying it otherwise

So how can similar behavior be implemented in Rust, given that it does not allow multiple mutable references?

harmic (Best Answer)

You are correct that the Rust compiler enforces that there can only be one mutable reference to a value at a time, but there is an escape hatch: the interior mutability pattern.

This pattern allows programmers to construct data structures for which the borrowing rules are checked at run time instead of at compile time.

The standard library provides a number of containers that implement interior mutability, with different usage patterns suitable for different scenarios. Key examples are:

  • RefCell<T>, which allows run time borrow checking for single threaded usage

  • RwLock<T>, which allows run time borrow checking for multi-threaded usage

  • Mutex<T>, which allows only one access at a time to its contents, whether reading or writing

There are others - see the module level documentation for std::cell and std::sync.
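
For example, here is a minimal sketch (plain std, nothing candle-specific) showing how RefCell moves the borrow checking from compile time to run time:

use std::cell::RefCell;

fn main() {
    let value = RefCell::new(5);

    // Any number of simultaneous immutable borrows is fine.
    let a = value.borrow();
    let b = value.borrow();
    println!("{} {}", *a, *b);
    drop(a);
    drop(b);

    // A mutable borrow succeeds once all readers are gone...
    *value.borrow_mut() += 1;

    // ...but trying to borrow mutably while another borrow is live
    // is caught at run time rather than at compile time.
    let _reader = value.borrow();
    assert!(value.try_borrow_mut().is_err());
}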

How does this apply to candle? Let's take a peek under the hood:

pub struct Tensor_ {
    // ...
    storage: Arc<RwLock<Storage>>,
    // ...
}

The contents of the storage that backs the tensor are protected by an RwLock. In fact, there are some comments in the code immediately above this which describe the reason for choosing this particular solution - worth a read.

Not only that, but this is in turn wrapped in an Arc<T> - which means that it is in fact a heap allocated, reference counted value. There can be multiple 'owners' of this value, and it will only be deallocated when the last owner goes out of scope.
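
To make that concrete, here is a small sketch (std only, not candle's actual code) of how multiple handles can share and mutate one storage buffer:

use std::sync::{Arc, RwLock};

fn main() {
    // Two independently owned handles to the same storage, mirroring
    // how several Tensor values can share one underlying buffer.
    let storage = Arc::new(RwLock::new(vec![0.0f32; 10]));
    let handle_a = Arc::clone(&storage);
    let handle_b = Arc::clone(&storage);

    // Either handle can take exclusive write access in turn.
    handle_a.write().unwrap()[0] = 1.0;
    handle_b.write().unwrap()[1] = 2.0;

    // Both observe the same underlying data.
    assert_eq!(handle_a.read().unwrap()[..2], [1.0, 2.0]);
    // The vector is freed only when the last Arc is dropped.
}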

How is this used in the case of backpropagation? Well, the backward() method of Tensor does not directly modify the tensor; rather, it returns a GradStore containing the computed gradients. A GradStore may in turn be consumed by an Optimizer. Optimizer is a trait with a couple of different implementations, so let's take a look at the SGD optimizer:

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
            }
        }
        Ok(())
    }

OK, so the gradients are here being applied to some Var instances - what are these (defined here)?

pub struct Var(Tensor);

Ok, a wrapper around a Tensor. And how does the set method do its job? This line is key:

let (mut dst, layout) = self.storage_mut_and_layout();

That gives us a mutable variable that seems to represent the destination for the set operation. What does this storage_mut_and_layout() method do?

let storage = self.storage.write().unwrap();

Aha! It calls the write() method on the RwLock we saw above, inside which the storage lives. The documentation for this method says:

Locks this RwLock with exclusive write access, blocking the current thread until it can be acquired.

This function will not return while other writers or other readers currently have access to the lock.

So in summary:

  • The backward() method itself does not seem to modify the input Tensor, but it returns a data structure containing the gradients
  • The gradients get applied to the Tensor using an Optimizer.
  • The Optimizer uses the set method to alter the Tensor, which under the hood gets mutable access to the Tensor's data storage using the write() method on the RwLock that is protecting it.
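
Putting these pieces together, a minimal training step looks roughly like this. This is a sketch assuming the candle_core and candle_nn crates; exact module paths and signatures may differ between candle versions:

use candle_core::{DType, Device, Result, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // A trainable variable: a Tensor flagged so backward() tracks it.
    let w = Var::ones(3, DType::F32, &dev)?;

    // Forward pass: loss = sum(w^2).
    let loss = w.as_tensor().sqr()?.sum_all()?;

    // backward() does not mutate w; it returns a GradStore.
    let grads = loss.backward()?;

    // The optimizer owns the Vars and applies the gradients via set(),
    // which takes the write lock on the underlying storage.
    let mut opt = SGD::new(vec![w], 0.1)?;
    opt.step(&grads)?;
    Ok(())
}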

kmdreko

The way to provide seemingly multiple mutable references in Rust is via interior mutability, which allows mutation through shared references. Rust still requires that mutation not happen at the same time as other access, but there are a few ways this can be ensured, and thus a few types that provide interior mutability generically: Cell, RefCell, Mutex, and RwLock. They all build on UnsafeCell, the core primitive that tells the compiler that & doesn't necessarily mean immutable for the contained value.
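
As a sketch of what UnsafeCell enables, here is a toy single-threaded cell (essentially a simplified std Cell, not anything from candle):

use std::cell::UnsafeCell;

// &MyCell<T> no longer implies that the T inside is immutable.
struct MyCell<T: Copy> {
    inner: UnsafeCell<T>,
}

impl<T: Copy> MyCell<T> {
    fn new(value: T) -> Self {
        Self { inner: UnsafeCell::new(value) }
    }
    fn get(&self) -> T {
        // Sound because we only hand out copies, never references.
        unsafe { *self.inner.get() }
    }
    fn set(&self, value: T) {
        // Mutation through &self: this is interior mutability.
        unsafe { *self.inner.get() = value }
    }
}

fn main() {
    let cell = MyCell::new(1);
    let alias_a = &cell;
    let alias_b = &cell;
    alias_a.set(2);               // mutate through one shared reference...
    assert_eq!(alias_b.get(), 2); // ...and observe through another
}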

If we look at candle's source, the fundamental Tensor contains an Arc which allows multiple tensor "handles" to refer to the same value - Arc provides shared ownership (source):

pub struct Tensor(Arc<Tensor_>);

The hidden inner Tensor_ type looks like this (source):

pub struct Tensor_ {
    id: TensorId,
    // As we provide inner mutability on the tensor content, the alternatives are:
    // - Using a mutex, this would have the highest cost when retrieving the storage but would
    //   prevent errors when concurrent access takes place. Mutex would also be subject to
    //   deadlocks for example using the current code if the same tensor is used twice by a single
    //   binary op.
    // - Using a refcell unsafe cell would have some intermediary cost, borrow checking would be
    //   verified dynamically, but the resulting tensors would not be send or sync.
    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
    //   accesses.
    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
    // that's tricky to encode in the current setup.
    storage: Arc<RwLock<Storage>>,
    layout: Layout,
    op: BackpropOp,
    is_variable: bool,
    dtype: DType,
    device: Device,
}

This struct conveniently has a comment that weighs the options for interior mutability of storage for us. The RwLock is the piece that allows multiple handles to access and/or mutate the tensor's contents.

So when backpropagation happens, it acquires a RwLockReadGuard on each relevant tensor's storage to read the data for the operation, and then releases those guards before doing anything with the resulting tensor, to avoid deadlocking (since the RwLock would block if mutation were attempted while an existing guard is held).
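
The hazard being avoided looks like this in miniature (std RwLock, not candle's code): the read guard must be dropped before the same thread takes the write lock:

use std::sync::{Arc, RwLock};

fn main() {
    let storage = Arc::new(RwLock::new(vec![1.0f32, 2.0]));

    // Read, compute, and let the read guard drop *before* writing.
    let doubled: Vec<f32> = {
        let data = storage.read().unwrap(); // RwLockReadGuard
        data.iter().map(|x| x * 2.0).collect()
    }; // guard released here

    // Safe to take the write lock now. If the read guard were still
    // alive on this thread, write() could block forever (a deadlock).
    *storage.write().unwrap() = doubled;

    assert_eq!(*storage.read().unwrap(), [2.0, 4.0]);
}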

It appears the library mostly does not make use of this interior mutability, since it prefers to create new tensors instead of mutating existing ones; the exception is a variable that should be updated to reflect new data. In that case it acquires a RwLockWriteGuard to swap the data out for the new value, and again quickly releases the guard.
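
A hedged sketch of that pattern (the helper set_in_place is my own name, not candle's API): take the write guard briefly, overwrite the data, and let the guard drop:

use std::sync::{Arc, RwLock};

// Hypothetical helper mimicking what Var::set does under the hood.
fn set_in_place(storage: &Arc<RwLock<Vec<f32>>>, new_data: &[f32]) {
    let mut guard = storage.write().unwrap(); // RwLockWriteGuard
    guard.clear();
    guard.extend_from_slice(new_data);
} // guard dropped here; the lock is released immediately

fn main() {
    let storage = Arc::new(RwLock::new(vec![0.0; 3]));
    set_in_place(&storage, &[1.0, 2.0, 3.0]);
    assert_eq!(*storage.read().unwrap(), [1.0, 2.0, 3.0]);
}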

It's hard to give exact lines for this in candle's source, since there are many layers for the backpropagation, storage operations, and tracking of results. I also can't give a concrete demonstration with formulas, since I'm not well versed in the subject. But I hope this is clear nonetheless and helps you with your own ventures.