I'm mostly an outsider trying to understand if Rust is appropriate for my projects.
There are frameworks that do automatic differentiation in Rust. Specifically, candle (and, I think, some other projects) somehow does it in a way similar to PyTorch, according to its description.
However, I know that Rust does not allow multiple mutable references. And it seems like that is what's needed for PyTorch-like automatic differentiation:
```python
import torch

x = torch.rand(10)  # an array of 10 elements
x.requires_grad = True
y = x.sin()
z = x**2
```
Both y and z must keep mutable references to x, because you might want to backpropagate through them, which will modify x.grad. For example:
```python
(y.dot(z)).backward()
print(x.grad)  # .backward() fills in x.grad (an array) without otherwise modifying x
```
So how can similar behavior be implemented in Rust, given that it does not allow multiple mutable references?
You are correct that the Rust compiler enforces that there can be only one mutable reference to a value at a time, but there is an escape hatch: the interior mutability pattern.
This pattern allows programmers to construct data structures for which the borrowing rules are checked at run time instead of at compile time.
The standard library provides a number of containers that implement interior mutability, with different usage patterns suitable for different scenarios. Key examples are:
- `RefCell<T>`, which allows run-time borrow checking for single-threaded usage
- `RwLock<T>`, which allows run-time borrow checking for multi-threaded usage
- `Mutex<T>`, which only allows one reference at a time to its contents
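As a minimal, self-contained illustration of the pattern (this is not candle code, just a sketch), `Rc<RefCell<T>>` lets several handles share and mutate the same value, with the exclusive-access rule checked at run time rather than by the borrow checker:

```rust
use std::cell::RefCell;
use std::rc::Rc;

fn main() {
    // Two handles to the same value; neither is a `&mut`, yet both can mutate it.
    let shared = Rc::new(RefCell::new(vec![0.0_f32; 3]));
    let a = Rc::clone(&shared);
    let b = Rc::clone(&shared);

    a.borrow_mut()[0] = 1.0; // exclusive access is checked at run time
    b.borrow_mut()[1] = 2.0; // fine: the previous mutable borrow has already ended

    println!("{:?}", shared.borrow()); // [1.0, 2.0, 0.0]
}
```

Overlapping `borrow_mut()` calls would panic at run time rather than fail to compile - that is exactly the trade-off interior mutability makes.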
There are others - see the module-level documentation for `cell` and `sync`.

How does this apply to candle? Let's take a peek under the hood:
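I won't copy the candle source verbatim here, but the relevant part of the tensor definition has roughly this shape (a simplified sketch - the real struct has more fields, and the real `Storage` is an enum over devices and dtypes):

```rust
use std::sync::{Arc, RwLock};

// Stand-in for candle's real Storage type.
struct Storage(Vec<f32>);

// Simplified sketch of the tensor internals: the element buffer sits behind
// an Arc<RwLock<...>>, so many tensors (e.g. y and z in the question) can
// share it, and any mutation is synchronised at run time.
struct Tensor_ {
    storage: Arc<RwLock<Storage>>,
    // ... plus layout, dtype, device, and the op graph used for backprop
}

// The public Tensor is itself just a cheap, reference-counted handle.
#[derive(Clone)]
struct Tensor(Arc<Tensor_>);
```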
The contents of the storage that backs the tensor are protected by an `RwLock`. In fact there are some comments in the code immediately above this which describe the reason for choosing this particular solution - worth a read. Not only that, but this is in turn wrapped in an `Arc<T>`, which means that it is in fact a heap-allocated, reference-counted value. There can be multiple 'owners' of this value, and it will only be deallocated when the last owner goes out of scope.

How is this used in the case of backpropagation? Well, the `backward()` method of `Tensor` does not directly modify the tensor; rather, it returns a `GradStore` containing the computed gradients. A `GradStore` may in turn be consumed by an `Optimizer`. `Optimizer` is a trait with a couple of different implementations, so let's take a look at the `SGD` optimizer:
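Its `step` method looks roughly like this (a paraphrased sketch of the candle-nn code rather than the exact source, so details may differ):

```rust
// Paraphrased sketch of an SGD step: for each trainable Var, look up its
// gradient in the GradStore and overwrite the Var's contents with
// `value - learning_rate * grad` via Var::set.
fn step(&mut self, grads: &GradStore) -> Result<()> {
    for var in self.vars.iter() {
        if let Some(grad) = grads.get(var) {
            var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
        }
    }
    Ok(())
}
```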
OK, so the gradients are here being applied to some `Var` instances - what are these (defined in the candle source)?
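In candle, `Var` is essentially just a newtype - something like this (a sketch; the real definition also carries a few derives):

```rust
// A Var is a Tensor that is marked as trainable; it is just a thin wrapper.
pub struct Var(Tensor);
```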
OK, a wrapper around a `Tensor`. And how does the `set` method do its job? This line is key:
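(Paraphrased here - the exact line in the candle source may differ slightly.)

```rust
// Inside Var::set (approximate): obtain write access to the destination
// tensor's storage, together with its layout, before copying the new values in.
let (mut dst, layout) = self.storage_mut_and_layout();
```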
That gives us a mutable variable that seems to represent the destination for the set operation. What does this `storage_mut_and_layout()` method do?
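Reusing the sketched types from above, its shape is roughly this (again a sketch, not the exact candle source; `Layout` is a stand-in for candle's shape/stride/offset type):

```rust
struct Layout; // stand-in for candle's Layout (shape, strides, start offset)

impl Tensor {
    // Sketch: acquire the write half of the RwLock that guards the storage,
    // and hand back the write guard together with the tensor's layout.
    fn storage_mut_and_layout(&self) -> (std::sync::RwLockWriteGuard<'_, Storage>, Layout) {
        (self.0.storage.write().unwrap(), Layout)
    }
}
```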
Aha! It calls the `write()` method on the `RwLock` we saw above, inside which the storage lives. The documentation for this method says:
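> Locks this `RwLock` with exclusive write access, blocking the current thread until it can be acquired.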
So in summary:

- The `backward()` method itself does not seem to modify the input `Tensor`, but it returns a data structure containing the gradients
- Those gradients may then be applied to the `Tensor` using an `Optimizer`
- The `Optimizer` uses the `set` method to alter the `Tensor`, which under the hood gets mutable access to the `Tensor`'s data storage using the `write()` method on the `RwLock` that is protecting it