Given two tensors A and B with the same number of dimensions d (d >= 2) and shapes [A_1, ..., A_{d-2}, A_{d-1}, A_d] and [A_1, ..., A_{d-2}, B_{d-1}, B_d] (the first d-2 dimensions are identical), is there a way to calculate the Kronecker product over the last two dimensions?
The shape of my_kron(A, B) should be [A_1, ..., A_{d-2}, A_{d-1}*B_{d-1}, A_d*B_d].
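Written elementwise, the requested operation is the following. Indices are 0-based, and n is notation introduced here for any multi-index over the first d-2 (shared) dimensions; it follows the block layout of np.kron, where A's index is the major one:

$$
C[n,\; i\,B_{d-1} + k,\; j\,B_d + l] = A[n, i, j]\, B[n, k, l],
\qquad 0 \le i < A_{d-1},\; 0 \le k < B_{d-1},\; 0 \le j < A_d,\; 0 \le l < B_d .
$$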
For example, with d = 3:
A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)
C[0,...] should be the Kronecker product of A[0,...] and B[0,...], and C[1,...] the Kronecker product of A[1,...] and B[1,...].
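As an unambiguous (but slow) reference for this shape rule, here is a minimal NumPy sketch that simply loops over every leading index with np.ndindex and calls np.kron on each pair of trailing matrices. The function name my_kron_reference is hypothetical; it assumes the leading dimensions of A and B match exactly.

```python
import numpy as np

def my_kron_reference(A, B):
    """Slow reference: np.kron applied to each pair of trailing 2-D slices."""
    batch = A.shape[:-2]
    if B.shape[:-2] != batch:
        raise ValueError("leading dimensions of A and B must match")
    out_shape = batch + (A.shape[-2] * B.shape[-2], A.shape[-1] * B.shape[-1])
    C = np.empty(out_shape, dtype=np.result_type(A, B))
    # np.ndindex walks every multi-index over the leading (batch) dimensions.
    for idx in np.ndindex(*batch):
        C[idx] = np.kron(A[idx], B[idx])
    return C

rng = np.random.default_rng(0)
A = rng.random((2, 3, 3))
B = rng.random((2, 4, 4))
C = my_kron_reference(A, B)
print(C.shape)  # (2, 12, 12)
```

It is handy mainly for checking faster versions against.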
For d = 2 this is simply what the jnp.kron (or np.kron) function does.
For d = 3 this can be achieved with jax.vmap:
jax.vmap(jnp.kron)(A, B)
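One way to extend the vmap idea to an unknown number of leading dimensions, sketched under the assumption that all leading dimensions of A and B match: wrap jnp.kron in jax.vmap once per leading dimension. Each wrap maps over axis 0 of both arguments, so after A.ndim - 2 wraps jnp.kron only ever sees 2-D inputs. The name my_kron_vmap is hypothetical.

```python
import jax
import jax.numpy as jnp

def my_kron_vmap(A, B):
    """Batched Kronecker product over the last two axes, any number of leading axes."""
    f = jnp.kron
    # Each jax.vmap peels off one leading (batch) axis from both arguments.
    for _ in range(A.ndim - 2):
        f = jax.vmap(f)
    return f(A, B)

A = jnp.arange(2 * 5 * 3 * 3, dtype=jnp.float32).reshape(2, 5, 3, 3)
B = jnp.arange(2 * 5 * 4 * 4, dtype=jnp.float32).reshape(2, 5, 4, 4)
C = my_kron_vmap(A, B)
print(C.shape)  # (2, 5, 12, 12)
```

The loop runs at trace time in Python, so the number of vmap layers is fixed by A.ndim; this also works inside jax.jit because array ranks are static.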
But I was not able to find a solution for general (unknown) dimensions. Any suggestions?
In NumPy terms, I think what you are doing is applying np.kron to each pair A[i], B[i] and stacking the results. That treats the initial dimension, the 2, as a batch. One obvious generalization is to reshape the arrays, collapsing the leading dimensions into one, e.g. reshape(-1, 3, 3), etc., and then afterwards reshape C back to the desired n dimensions.

np.kron does accept 3d (and higher) arrays, but it does a sort of outer product over the shared dimension of size 2, so the result has a leading dimension of size 4. Visualizing that dimension as (2, 2), I can take the diagonal and get your C.

The full kron does more calculations than needed, but is still faster than the per-pair loop.

I'm sure it's possible to do your calculation in a more direct way, but doing that requires a better understanding of how kron works. A quick glance at the np.kron code suggests that it does an outer(A, B), which has the same number of elements, but then reshapes and concatenates the result to produce the kron layout.

But following a hunch, I found that this is equivalent to what you want:
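A direct formulation consistent with the outer-product-then-reshape description, offered as a sketch (not necessarily the exact snippet referred to here): insert singleton axes so the elementwise product broadcasts to (..., m, p, n, q), then reshape. It assumes A has shape (..., m, n) and B has shape (..., p, q) with identical leading dimensions; my_kron is the hypothetical name from the question. The same indexing and reshape work unchanged on jax.numpy arrays.

```python
import numpy as np

def my_kron(A, B):
    """Kronecker product over the last two axes of A (..., m, n) and B (..., p, q)."""
    m, n = A.shape[-2:]
    p, q = B.shape[-2:]
    # Broadcast to (..., m, p, n, q): prod[..., i, k, j, l] = A[..., i, j] * B[..., k, l]
    prod = A[..., :, None, :, None] * B[..., None, :, None, :]
    # A C-order reshape merges (i, k) -> i*p + k and (j, l) -> j*q + l,
    # which is exactly the np.kron block layout.
    return prod.reshape(A.shape[:-2] + (m * p, n * q))

A = np.arange(2 * 3 * 3).reshape(2, 3, 3)
B = np.arange(2 * 4 * 4).reshape(2, 4, 4)
C = my_kron(A, B)
print(C.shape)  # (2, 12, 12)
print(np.array_equal(C[1], np.kron(A[1], B[1])))  # True
```

Because of the Ellipsis, the leading dimensions are never spelled out, so d does not need to be known in advance.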
That is easily generalized to more leading dimensions.
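To illustrate the generalization with a d = 4 example, here is a sketch that swaps in np.einsum to build the interleaved (..., m, p, n, q) product; the "..." in the subscripts stands for any number of shared leading dimensions, and jnp.einsum accepts the same subscripts. The name my_kron_einsum is hypothetical, and identical leading dimensions are assumed.

```python
import numpy as np

def my_kron_einsum(A, B):
    """Kronecker product over the last two axes, built with einsum."""
    m, n = A.shape[-2:]
    p, q = B.shape[-2:]
    # '...ij,...kl->...ikjl' gives prod[..., i, k, j, l] = A[..., i, j] * B[..., k, l]
    prod = np.einsum("...ij,...kl->...ikjl", A, B)  # shape (..., m, p, n, q)
    return prod.reshape(A.shape[:-2] + (m * p, n * q))

rng = np.random.default_rng(1)
A = rng.random((2, 5, 3, 3))  # d = 4: two leading dimensions
B = rng.random((2, 5, 4, 4))
C = my_kron_einsum(A, B)
print(C.shape)  # (2, 5, 12, 12)
```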