(PyTorch) How to improve the inference speed in this case?


Can you tell me how to improve the inference speed in this case?

    import torch

    W = torch.rand(768, 768)
    X = torch.rand(128, 128, 768)

I want to use only the 2nd and 4th quadrants of W (the top-left and bottom-right 384×384 blocks) for the matrix multiplication. The 1st and 3rd quadrants of W are all zeros, so W is block-diagonal.
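Because the 1st and 3rd quadrants are zero, each half of the output depends on only one half of X: (X @ W)[..., :384] = X[..., :384] @ W[:384, :384] and (X @ W)[..., 384:] = X[..., 384:] @ W[384:, 384:]. A minimal sketch that checks this identity, building W with torch.block_diag (the names W_tl and W_br are illustrative, and X uses a smaller batch than the question's 128×128 to keep it quick):

```python
import torch

torch.manual_seed(0)
d = 768
h = d // 2

# Illustrative diagonal blocks: 2nd quadrant (top-left), 4th quadrant (bottom-right)
W_tl = torch.rand(h, h)
W_br = torch.rand(h, h)

# torch.block_diag puts the blocks on the diagonal and fills the
# 1st (top-right) and 3rd (bottom-left) quadrants with zeros
W = torch.block_diag(W_tl, W_br)
assert torch.all(W[:h, h:] == 0) and torch.all(W[h:, :h] == 0)

X = torch.rand(4, 8, d)  # smaller batch than 128 x 128, same feature size

dense = X @ W
split = torch.cat([X[..., :h] @ W_tl, X[..., h:] @ W_br], dim=-1)

# The two half-size products reproduce the dense product up to float rounding
print(torch.allclose(dense, split, rtol=1e-4, atol=1e-3))
```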

For example, I compute it like this:

    X[:, :, :384] = X[:, :, :384].matmul(W[0:384, 0:384])
    X[:, :, 384:768] = X[:, :, 384:768].matmul(W[384:768, 384:768])
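The two slice assignments can also be written as one batched product: view the last dimension of X as 2 chunks of 384 features, stack the two diagonal blocks of W into a (2, 384, 384) tensor, and contract them with a single torch.einsum call. This is a sketch assuming the block size divides the feature size evenly; block_diag_matmul is a hypothetical helper name, not a PyTorch API.

```python
import torch


def block_diag_matmul(X: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
    """Compute X @ torch.block_diag(*blocks) without touching the zero quadrants.

    X:      (..., k * b) input whose last dim is split into k chunks of size b
    blocks: (k, b, b) stacked diagonal blocks of W
    """
    k, b, _ = blocks.shape
    # (..., k*b) -> (..., k, b); chunk j holds features j*b .. (j+1)*b - 1
    Xr = X.reshape(*X.shape[:-1], k, b)
    # One call for all chunks: out[..., j, :] = Xr[..., j, :] @ blocks[j]
    out = torch.einsum("...kd,kde->...ke", Xr, blocks)
    return out.reshape(*X.shape[:-1], k * b)


W = torch.rand(768, 768)
X = torch.rand(128, 128, 768)

# Stack the 2nd (top-left) and 4th (bottom-right) quadrants: shape (2, 384, 384)
blocks = torch.stack([W[:384, :384], W[384:, 384:]])

with torch.inference_mode():
    Y = block_diag_matmul(X, blocks)
    # Reference: the question's two slice products, written out of place
    ref = torch.cat([X[..., :384] @ W[:384, :384], X[..., 384:] @ W[384:, 384:]], dim=-1)

print(Y.shape)  # torch.Size([128, 128, 768])
print(torch.allclose(Y, ref, rtol=1e-4, atol=1e-3))
```

Compared with the slice version, the result is produced by one einsum call and lands directly in a fresh output tensor instead of being copied back into X. Whether that beats the plain dense X @ W still has to be measured on the target device.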

With this code I tried to make the ordinary X @ W matrix multiplication faster, but inference actually becomes slower. How can I implement this efficiently?
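Some likely reasons for the slowdown: a single dense 768×768 matmul is already well optimized, skipping the zero quadrants saves at most half of the FLOPs, and the split version pays for that with two smaller matmul calls instead of one plus an extra copy of each half-result back into X. Because the balance depends on the hardware, it is worth timing both variants. A sketch with torch.utils.benchmark; the split function below writes into a preallocated buffer Y (an assumption, so repeated runs do not keep rescaling X), and no timings are shown because they vary by machine:

```python
import torch
from torch.utils import benchmark


def dense(X, W):
    # Ordinary matmul over the full W, zero quadrants included
    return X @ W


def split(X, W, Y):
    # The question's approach: two half-size matmuls copied into a buffer.
    # Y is preallocated so repeated runs do not keep overwriting X.
    h = W.shape[0] // 2
    Y[..., :h] = X[..., :h] @ W[:h, :h]
    Y[..., h:] = X[..., h:] @ W[h:, h:]
    return Y


def compare(batch=128, seq=128, d=768, number=3):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    W = torch.block_diag(torch.rand(d // 2, d // 2), torch.rand(d // 2, d // 2)).to(device)
    X = torch.rand(batch, seq, d, device=device)
    Y = torch.empty_like(X)
    env = {"dense": dense, "split": split, "X": X, "W": W, "Y": Y}
    results = {}
    with torch.inference_mode():
        for name, stmt in [("dense", "dense(X, W)"), ("split", "split(X, W, Y)")]:
            m = benchmark.Timer(stmt=stmt, globals=env).timeit(number)
            results[name] = m.median  # seconds per call
    return results


print(compare())
```

If the dense version still wins on your hardware, keeping W dense (zeros included) may simply be the faster choice at this size.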

Additionally, is there a "divide-and-conquer" library in PyTorch?
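For this particular pattern (equal-size diagonal blocks), PyTorch already has a primitive that exploits the structure: a grouped convolution with kernel size 1. With groups=2, output channels 0–383 only see input channels 0–383, and output channels 384–767 only see input channels 384–767, which is exactly a block-diagonal linear map. A sketch using torch.nn.functional.conv1d, treating the 768 features as channels; conv weights are laid out as (out_channels, in_channels // groups, kernel_size), so each block is transposed. block_diag_conv is a hypothetical helper name.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def block_diag_conv(X: torch.Tensor, blocks: Sequence[torch.Tensor]) -> torch.Tensor:
    """X @ torch.block_diag(*blocks) for X of shape (B, L, D), via grouped conv1d."""
    groups = len(blocks)
    # Conv computes out[o] = sum_i weight[o, i] * x[i], while X @ W computes
    # out[o] = sum_i x[i] * W[i, o], so each block is transposed.
    # Resulting weight shape: (D, D // groups, 1)
    weight = torch.cat([blk.t() for blk in blocks], dim=0).unsqueeze(-1)
    # conv1d expects (B, C, L): move the feature dim to the channel position
    out = F.conv1d(X.transpose(1, 2), weight, groups=groups)
    return out.transpose(1, 2)


W = torch.rand(768, 768)
X = torch.rand(128, 128, 768)

with torch.inference_mode():
    Y = block_diag_conv(X, [W[:384, :384], W[384:, 384:]])
    ref = torch.cat([X[..., :384] @ W[:384, :384], X[..., 384:] @ W[384:, 384:]], dim=-1)

print(torch.allclose(Y, ref, rtol=1e-4, atol=1e-3))
```

As with the other variants, whether this is faster than the dense matmul depends on the backend and should be measured. On GPUs, cuDNN convolutions may use TF32 by default (torch.backends.cudnn.allow_tf32), so a comparison there needs a looser tolerance.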

Help me please.
