Rust: get a mutable reference to each element of an ndarray in parallel


I am working on a parallel matrix multiplication in Rust, where I want to compute every element of the product in parallel. I store my data in `ndarray` arrays (`Array2<f32>`). My code would look something along the lines of:

fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let n = lhs.raw_dim()[0];
    let m = rhs.raw_dim()[1];
    let mut result = Array2::zeros((n, m));

    // Pseudocode: `range_2d` stands for a parallel iterator over every (i, j)
    // index pair; it is not an existing rayon function.
    range_2d(0..n, 0..m).par_iter().for_each(|(i, j)| {
        // store the result for the (i, j) element into `result`
    });

    result
}

Is there any way to get a mutable reference to each element of `result` so that the elements can be filled in parallel?
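For reference, here is a minimal sequential sketch of what the closure body has to compute for each (i, j) pair. It assumes flat row-major `Vec<f32>` storage with explicit dimensions instead of `Array2`, so it runs with the standard library alone; the function name `mul_seq` and its parameters are hypothetical.

```rust
/// Sequential reference for `lhs * rhs` (hypothetical helper).
/// Assumption: `lhs` is an n x k matrix and `rhs` a k x m matrix,
/// both stored as flat row-major slices rather than ndarray's `Array2`.
fn mul_seq(lhs: &[f32], rhs: &[f32], n: usize, k: usize, m: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), n * k, "lhs must be n x k");
    assert_eq!(rhs.len(), k * m, "rhs must be k x m");
    let mut result = vec![0.0f32; n * m];
    for i in 0..n {
        for j in 0..m {
            // Element (i, j) is the dot product of row i of lhs and column j of rhs.
            result[i * m + j] = (0..k).map(|p| lhs[i * k + p] * rhs[p * m + j]).sum::<f32>();
        }
    }
    result
}
```

Every (i, j) iteration writes a different element and only reads `lhs` and `rhs`, which is why the elements can in principle be computed independently; the difficulty in Rust is handing out those disjoint `&mut` references safely.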

Answer by Chayim Friedman (best answer)

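You can get a parallel iterator over mutable references to every element by iterating over the rows of `result` in parallel and then over the elements of each row (this needs ndarray's `rayon` crate feature):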

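use ndarray::{Array2, Axis};
// `into_par_iter()` on ndarray's axis iterators requires ndarray's `rayon` feature.
use rayon::prelude::*;

pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    assert_eq!(lhs.ncols(), rhs.nrows(), "inner dimensions must match");
    let n = lhs.nrows();
    let m = rhs.ncols();
    let mut result = Array2::<f32>::zeros((n, m));

    result
        // Iterate over the rows of `result` mutably, in parallel.
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .flat_map(|(i, row)| {
            // A row of a freshly allocated (standard-layout) array is contiguous,
            // so `into_slice` succeeds and yields a `&mut [f32]`.
            row.into_slice()
                .unwrap()
                .par_iter_mut()
                .enumerate()
                .map(move |(j, item)| (i, j, item))
        })
        .for_each(|(i, j, item)| {
            // Element (i, j) is the dot product of row i of `lhs` and column j of `rhs`.
            *item = lhs.row(i).dot(&rhs.column(j));
        });

    result
}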
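The key idea in the answer is to split `result` into disjoint mutable pieces (rows, then elements) instead of sharing one `&mut result` across threads. The sketch below shows the same idea with only the standard library, swapping rayon for `std::thread::scope` (stable since Rust 1.63) and assuming flat row-major slices instead of `Array2`; `mul_par` and its `threads` parameter are hypothetical.

```rust
use std::thread;

/// Parallel sketch of `lhs * rhs` (hypothetical helper, std only).
/// Assumption: `lhs` is n x k and `rhs` is k x m, both flat row-major slices.
/// Each worker receives a disjoint `&mut` block of rows, so no element is aliased.
fn mul_par(lhs: &[f32], rhs: &[f32], n: usize, k: usize, m: usize, threads: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), n * k, "lhs must be n x k");
    assert_eq!(rhs.len(), k * m, "rhs must be k x m");
    let mut result = vec![0.0f32; n * m];
    if n == 0 || m == 0 {
        // Nothing to compute; also avoids `chunks_mut(0)`, which panics.
        return result;
    }
    let workers = threads.max(1);
    // Rows per worker, rounded up so every row is covered.
    let rows_per_chunk = (n + workers - 1) / workers;
    thread::scope(|s| {
        // `chunks_mut` yields non-overlapping mutable sub-slices, one per worker.
        for (chunk_idx, chunk) in result.chunks_mut(rows_per_chunk * m).enumerate() {
            let first_row = chunk_idx * rows_per_chunk;
            s.spawn(move || {
                // Split the block into rows, then each row into `&mut f32` elements.
                for (r, row) in chunk.chunks_mut(m).enumerate() {
                    let i = first_row + r;
                    for (j, item) in row.iter_mut().enumerate() {
                        *item = (0..k).map(|p| lhs[i * k + p] * rhs[p * m + j]).sum::<f32>();
                    }
                }
            });
        }
    }); // All spawned threads are joined here, ending the borrows of `result`.
    result
}
```

Unlike rayon's work-stealing `par_iter_mut`, this sketch gives each thread one fixed contiguous block of rows, which is simpler but does not rebalance uneven work.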