Product of arrays (einsum) of 3D arrays containing only -1 or +1


Let X be an array of shape (M, k, g) and Q be an array of shape (m, k, g), where m, M, k, and g can be very large. Suppose the entries of X and Q are either -1 or +1. I'm interested in the array Z of shape (m, M, k) defined by Z[a,b,i] = Q[a,i] @ X[b,i], i.e. the dot product over the last axis. This can be computed in NumPy (or JAX, to exploit a GPU) like so:

Z = einsum("mkg,Mkg->mMk", Q, X)
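To make the definition concrete, here is a minimal sketch that checks the einsum against the explicit triple loop. It assumes NumPy; the sizes and the random seed are illustrative and not taken from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
m, M, k, g = 2, 3, 4, 5  # tiny illustrative sizes (assumption)

# Random arrays with entries in {-1, +1}
Q = rng.choice([-1, 1], size=(m, k, g))
X = rng.choice([-1, 1], size=(M, k, g))

# Vectorized form from the question
Z = np.einsum("mkg,Mkg->mMk", Q, X)

# Reference: Z[a, b, i] is the dot product of Q[a, i] and X[b, i]
Z_ref = np.empty((m, M, k), dtype=Q.dtype)
for a in range(m):
    for b in range(M):
        for i in range(k):
            Z_ref[a, b, i] = Q[a, i] @ X[b, i]

print(np.array_equal(Z, Z_ref))
```

Since every term in each dot product is -1 or +1, each entry of Z lies between -g and g and has the same parity as g.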

Question. Can Z be computed more efficiently by using the information that the entries of X and Q are -1 or +1?


4
Learning is a mess On BEST ANSWER

As I mentioned in my comment, casting to int8 seems to give a speedup. If you are able and willing to use torch, I also see a speedup with its einsum (I cannot say which optimizations it uses), although it seems slower when casting to torch's int8 type:

[screenshot of timings]
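The int8 cast can be sketched as follows. This is an illustration assuming NumPy, with small made-up sizes; it is not the code behind the timings above. One caveat is that an int8 einsum also produces int8 output, so the sums can overflow once g exceeds 127.

```python
import numpy as np

rng = np.random.default_rng(0)
m, M, k, g = 4, 5, 3, 16  # illustrative sizes (assumption); g <= 127 here

# Floating-point inputs with entries in {-1, +1}
Q = rng.choice([-1, 1], size=(m, k, g)).astype(np.float64)
X = rng.choice([-1, 1], size=(M, k, g)).astype(np.float64)
Z_float = np.einsum("mkg,Mkg->mMk", Q, X)

# Same data stored as int8: one byte per entry instead of eight
Q8 = Q.astype(np.int8)
X8 = X.astype(np.int8)

# The result is int8 as well, so |Z| <= g must fit in [-128, 127].
# For larger g, a wider integer type for the inputs would be needed.
Z_int8 = np.einsum("mkg,Mkg->mMk", Q8, X8)

print(Z_int8.dtype, np.array_equal(Z_float, Z_int8))
```

Whether the cast pays off depends on the array sizes and on the hardware, so it is worth timing on the actual shapes.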


0
dohmatob On

Here are some more benchmarks, including on GPU (T4), to complement the accepted answer. I compared NumPy, JAX, and PyTorch.

CPU

[screenshot of CPU timings]

GPU

[screenshot of GPU timings]

The conclusion (based on this tiny experiment) is:

  • Only NumPy seems to take full advantage of int8. The others are probably doing some casting under the hood.
  • GPU helps JAX.
  • PyTorch is a beast on GPU.
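For anyone who wants to repeat this kind of comparison on their own machine, here is a minimal timing sketch. It covers only the NumPy CPU case, and the helper name `time_einsum`, the sizes, and the repeat counts are assumptions for illustration, not the setup used for the screenshots above.

```python
import timeit
import numpy as np

def time_einsum(dtype, m=32, M=32, k=16, g=64, repeats=5, number=10):
    """Return the best per-call time (seconds) of the einsum for one dtype."""
    rng = np.random.default_rng(0)
    Q = rng.choice([-1, 1], size=(m, k, g)).astype(dtype)
    X = rng.choice([-1, 1], size=(M, k, g)).astype(dtype)
    timer = timeit.Timer(lambda: np.einsum("mkg,Mkg->mMk", Q, X))
    # Take the minimum over several repeats to reduce noise
    return min(timer.repeat(repeat=repeats, number=number)) / number

# g = 64 keeps int8 results within range (no overflow)
timings = {
    np.dtype(dt).name: time_einsum(dt)
    for dt in (np.float64, np.float32, np.int8)
}

for name, seconds in timings.items():
    print(f"{name:>8}: {seconds * 1e6:.1f} us per call")
```

The numbers will vary with shapes, NumPy build, and hardware, so the ranking seen here may not carry over to other setups.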