Can you tell me how to improve the inference speed in this case?
W = torch.rand(768, 768)
X = torch.rand(128, 128, 768)
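I want to use only the 2nd and 4th quadrants of W in the matrix multiplication, because the 1st and 3rd quadrants are all zeros. In other words, W is block-diagonal: only the top-left block W[0:384, 0:384] and the bottom-right block W[384:768, 384:768] are non-zero.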
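A minimal sketch of the setup described above, assuming the two non-zero blocks are arbitrary 384 × 384 matrices (the names `A` and `B` are hypothetical). `torch.block_diag` builds the 768 × 768 matrix whose off-diagonal quadrants are zero:

```python
import torch

torch.manual_seed(0)

n = 384  # size of each diagonal block (768 / 2)

# Hypothetical non-zero blocks: A is the top-left quadrant, B the bottom-right.
A = torch.rand(n, n)
B = torch.rand(n, n)

# Dense 768 x 768 matrix with A and B on the diagonal and zeros elsewhere.
W = torch.block_diag(A, B)

# The two off-diagonal quadrants contain only zeros.
top_right_zero = bool((W[:n, n:] == 0).all())
bottom_left_zero = bool((W[n:, :n] == 0).all())
print(W.shape, top_right_zero, bottom_left_zero)  # torch.Size([768, 768]) True True
```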
For example, I compute:
X[:, :, :384] = X[:, :, :384].matmul(W[0:384, 0:384])
X[:, :, 384:768] = X[:, :, 384:768].matmul(W[384:768, 384:768])
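One way to confirm that the two-slice computation matches the full product is to compare both on random data. This sketch builds W with `torch.block_diag` from two hypothetical blocks `A` and `B`, uses a smaller X than the (128, 128, 768) in the question so it runs quickly, and collects the result with `torch.cat` instead of assigning back into slices of X:

```python
import torch

torch.manual_seed(0)

n = 384
A = torch.rand(n, n)  # hypothetical top-left block of W
B = torch.rand(n, n)  # hypothetical bottom-right block of W
W = torch.block_diag(A, B)

# Smaller batch than (128, 128, 768) so the check runs quickly.
X = torch.rand(8, 16, 2 * n)

# Reference: ordinary dense product, zero quadrants included.
Y_full = X @ W

# The two-slice version, written out of place: each half of the
# features is multiplied by its own block and the halves are concatenated.
Y_blocks = torch.cat([X[..., :n] @ A, X[..., n:] @ B], dim=-1)

# The zero quadrants contribute nothing, so both results agree
# up to float32 rounding from the different summation order.
print(torch.allclose(Y_full, Y_blocks, rtol=1e-4, atol=1e-4))
```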
With this code I hoped to make the computation faster than the plain X @ W product, since it skips the zero half of W. Instead, inference becomes slower. How can I implement this efficiently?
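Halving the FLOPs does not necessarily halve the runtime: slicing the last dimension of X yields strided views, each slice is a separate matmul call, and the results are written back into X. One possible direction, offered as an unmeasured sketch, is to store only the two non-zero blocks as a single (2, 384, 384) tensor and compute both block products in one batched `torch.einsum` call. The helper name `block_diag_matmul` and the blocks `A` and `B` are hypothetical:

```python
import torch

torch.manual_seed(0)

n, groups = 384, 2
A = torch.rand(n, n)  # hypothetical top-left block of W
B = torch.rand(n, n)  # hypothetical bottom-right block of W

# Store only the non-zero blocks: shape (groups, n, n).
W_blocks = torch.stack([A, B])


def block_diag_matmul(X: torch.Tensor, W_blocks: torch.Tensor) -> torch.Tensor:
    """Compute X @ block_diag(*W_blocks) without touching the zero quadrants."""
    g, n_in, n_out = W_blocks.shape
    # Split the last dimension into (groups, n_in); a view for contiguous X.
    Xg = X.reshape(*X.shape[:-1], g, n_in)
    # One batched contraction over all groups instead of two separate matmuls.
    Yg = torch.einsum("...gi,gij->...gj", Xg, W_blocks)
    # Merge (groups, n_out) back into a single feature dimension.
    return Yg.reshape(*X.shape[:-1], g * n_out)


X = torch.rand(8, 16, groups * n)
Y = block_diag_matmul(X, W_blocks)

# Same result as the dense product with the zeros included.
W = torch.block_diag(A, B)
print(Y.shape, torch.allclose(Y, X @ W, rtol=1e-4, atol=1e-4))
```

Whether this beats a single dense `X @ W` depends on the device and the PyTorch version; with only two blocks, a well-tuned dense BLAS or cuBLAS kernel can still win. `torch.utils.benchmark.Timer`, which handles warm-up and CUDA synchronization, is a reliable way to time each variant.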
Additionally, is there a "divide-and-conquer" library in PyTorch?
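For the specific case of a block-diagonal weight, grouped convolution is one built-in operation that already splits the computation into independent groups. The sketch below swaps in `torch.nn.functional.conv1d` with `groups=2` and kernel size 1, which applies a separate 384 → 384 linear map to each half of the features. The blocks `A` and `B` and the helper `grouped_conv_matmul` are hypothetical:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n, groups = 384, 2
A = torch.rand(n, n)  # hypothetical top-left block of W
B = torch.rand(n, n)  # hypothetical bottom-right block of W

# conv1d weight layout is (out_channels, in_channels // groups, kernel_size).
# Output channel o of group g computes sum_i weight[o, i] * x[g*n + i], so each
# block is transposed relative to the X @ W convention.
conv_weight = torch.cat([A.T, B.T], dim=0).unsqueeze(-1)  # (768, 384, 1)


def grouped_conv_matmul(X: torch.Tensor) -> torch.Tensor:
    """X @ block_diag(A, B) via a 1x1 grouped convolution over the last dim."""
    # conv1d expects (N, C, L): treat features as channels, sequence as length.
    x = X.transpose(1, 2)                        # (batch, 768, seq)
    y = F.conv1d(x, conv_weight, groups=groups)  # (batch, 768, seq)
    return y.transpose(1, 2)                     # back to (batch, seq, 768)


X = torch.rand(8, 16, groups * n)
Y = grouped_conv_matmul(X)

# Compare against the dense product with the zero quadrants included.
W = torch.block_diag(A, B)
print(Y.shape, torch.allclose(Y, X @ W, rtol=1e-4, atol=1e-4))
```

Grouped convolutions are not always faster per FLOP than dense ones, so this is an option to measure rather than a guaranteed speedup.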
Help me please.