How to properly satisfy trait bounds in Rust while using ndarray structs


I am trying to implement a Tensor struct that holds an Array from the ndarray crate, where T is the element type and I determines the dimensionality. The struct should hold either a two-dimensional or a three-dimensional array.

use std::fmt::Debug;
use ndarray::ArrayBase;
use ndarray::prelude::*;
use ndarray::OwnedRepr;
use ndarray::{Array};
use ndarray::Array3;
use ndarray::Dimension;

struct Tensor<T,I> 
where T: Clone + Sync + Send,
I: PartialEq + Debug 
{
   data: ArrayBase<OwnedRepr<T>,Dim<I>>
}


impl<T,I> Tensor<T,I>
where T: Clone + Sync + Send,
I: PartialEq + Debug 
{    
     pub fn new(arr:ArrayBase<OwnedRepr<T>,Dim<I>>) -> Self {
       Self {data: arr }
     }

     pub fn shape(&self) -> &[usize] {
       self.data.shape()
     }
}


fn main() {
    let mut temperature = Array3::<f32>::zeros((3, 4, 5));
    let shape = temperature.shape();
    let t3D:Tensor<f32,[usize;3]> = Tensor::new(temperature);
}

I want to add a shape() method to the Tensor struct that returns the shape of the ndarray stored in the data field.

But rustc reports the following error:

error[E0599]: the method `shape` exists for struct `ArrayBase<OwnedRepr<T>, Dim<I>>`, but its trait bounds were not satisfied
  --> src/main.rs:36:20
   |
36 |          self.data.shape()
   |                    ^^^^^ method cannot be called on `ArrayBase<OwnedRepr<T>, Dim<I>>` due to unsatisfied trait bounds


doesn't satisfy `<_ as DimAdd<Dim<IxDynImpl>>>::Output = Dim<IxDynImpl>`
| doesn't satisfy `<_ as DimAdd<Dim<[usize; 0]>>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimAdd<Dim<[usize; 1]>>>::Output = <Dim<I> as Dimension>::Larger`
| doesn't satisfy `<_ as DimMax<<Dim<I> as Dimension>::Larger>>::Output = <Dim<I> as Dimension>::Larger`
| doesn't satisfy `<_ as DimMax<<Dim<I> as Dimension>::Smaller>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimMax<Dim<I>>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimMax<Dim<IxDynImpl>>>::Output = Dim<IxDynImpl>`
| doesn't satisfy `<_ as DimMax<Dim<[usize; 0]>>>::Output = Dim<I>`
| doesn't satisfy `<_ as Index<usize>>::Output = usize`
| doesn't satisfy `<_ as Mul<usize>>::Output = Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Add>::Output = ndarray::Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Mul>::Output = ndarray::Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Sub>::Output = ndarray::Dim<I>`
| doesn't satisfy `_: DimAdd<<Dim<I> as Dimension>::Larger>`
| doesn't satisfy `_: DimAdd<<Dim<I> as Dimension>::Smaller>`
| doesn't satisfy `_: DimMax<<Dim<I> as Dimension>::Larger>`
| doesn't satisfy `_: DimMax<<Dim<I> as Dimension>::Smaller>`
| doesn't satisfy `ndarray::Dim<I>: AddAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: AddAssign`
| doesn't satisfy `ndarray::Dim<I>: Add`
| doesn't satisfy `ndarray::Dim<I>: Clone`
| doesn't satisfy `ndarray::Dim<I>: Default`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<IxDynImpl>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<[usize; 0]>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<[usize; 1]>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<IxDynImpl>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<[usize; 0]>>`
| doesn't satisfy `ndarray::Dim<I>: Dimension`
| doesn't satisfy `ndarray::Dim<I>: Eq`
| doesn't satisfy `ndarray::Dim<I>: IndexMut<usize>`
| doesn't satisfy `ndarray::Dim<I>: Mul<usize>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign<usize>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign`
| doesn't satisfy `ndarray::Dim<I>: Mul`
| doesn't satisfy `ndarray::Dim<I>: Send`
| doesn't satisfy `ndarray::Dim<I>: SubAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: SubAssign`
| doesn't satisfy `ndarray::Dim<I>: Sub`
| doesn't satisfy `ndarray::Dim<I>: Sync`
| doesn't satisfy `ndarray::Dim<I>: std::ops::Index<usize>`

I tried adding trait bounds to both the struct and the impl block, but the compiler keeps reporting this error.

So I have a couple of questions:

  1. How do I add the trait bounds properly to the above code?
  2. How do I identify the trait bounds a particular method needs?

There are 2 best solutions below

Chayim Friedman

You can add a where clause requiring Dim<I>: Dimension.

Jmb

The easiest fix is to add the bound Dim<I>: Dimension to your impl block (note, by the way, that you can remove the bounds from the struct itself):

use ndarray::prelude::*;
use ndarray::Array;
use ndarray::Array3;
use ndarray::ArrayBase;
use ndarray::Dimension;
use ndarray::OwnedRepr;
use std::fmt::Debug;

struct Tensor<T, I> {
    data: ArrayBase<OwnedRepr<T>, Dim<I>>,
}

impl<T, I> Tensor<T, I>
where
    T: Clone + Sync + Send,
    I: PartialEq + Debug,
    Dim<I>: Dimension,
{
    pub fn new(arr: ArrayBase<OwnedRepr<T>, Dim<I>>) -> Self {
        Self { data: arr }
    }

    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }
}

fn main() {
    let mut temperature = Array3::<f32>::zeros((3, 4, 5));
    let shape = temperature.shape();
    let t3D: Tensor<f32, [usize; 3]> = Tensor::new(temperature);
}


This can be deduced from the docs, because the impl block on ArrayBase that defines the shape method is declared as:

impl<A, S, D> ArrayBase<S, D> where
    S: RawData<Elem = A>,
    D: Dimension, 

In your case D is Dim<I>, so you want Dim<I>: Dimension.
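
To see why the bound has to be restated on your own impl, here is a minimal std-only sketch of the same pattern. The Dimension trait, Dim, and ArrayBase below are simplified stand-ins written for illustration (they are not ndarray's real definitions), and demo_shape is a hypothetical helper. The point it tries to show is that a method defined in an impl block gated on D: Dimension can only be called from generic code that promises the same bound.

```rust
// Simplified stand-ins for illustration only; NOT ndarray's real definitions.
trait Dimension {
    fn slice(&self) -> &[usize];
}

struct Dim<I>(I);

// Only some index types turn `Dim<I>` into a dimension,
// so `Dim<I>: Dimension` does not hold for an arbitrary `I`.
impl Dimension for Dim<[usize; 2]> {
    fn slice(&self) -> &[usize] {
        &self.0
    }
}

impl Dimension for Dim<[usize; 3]> {
    fn slice(&self) -> &[usize] {
        &self.0
    }
}

struct ArrayBase<A, D> {
    #[allow(dead_code)]
    data: Vec<A>,
    dim: D,
}

// `shape` lives in an impl block gated on `D: Dimension`,
// mirroring the impl header quoted from the docs above.
impl<A, D> ArrayBase<A, D>
where
    D: Dimension,
{
    fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
}

// No bounds are needed on the struct itself.
struct Tensor<T, I> {
    data: ArrayBase<T, Dim<I>>,
}

// Here `D` is `Dim<I>`, so the impl must promise `Dim<I>: Dimension`.
// Removing this where clause should make the `shape` call fail to compile
// with an "unsatisfied trait bounds" error like the one in the question.
impl<T, I> Tensor<T, I>
where
    Dim<I>: Dimension,
{
    fn new(arr: ArrayBase<T, Dim<I>>) -> Self {
        Self { data: arr }
    }

    fn shape(&self) -> &[usize] {
        self.data.shape()
    }
}

// Hypothetical helper: builds a 3-D tensor and returns its shape.
fn demo_shape() -> Vec<usize> {
    let arr = ArrayBase {
        data: vec![0.0_f32; 60],
        dim: Dim([3, 4, 5]),
    };
    let tensor: Tensor<f32, [usize; 3]> = Tensor::new(arr);
    tensor.shape().to_vec()
}
```

Calling demo_shape() from main yields the three axis lengths that were passed to Dim. The same way of reading applies to any method you want to call: find the impl block that defines it in the docs, substitute your concrete type parameters into its where clause, and copy the resulting bounds onto your own impl.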