I want to use multiple GPUs to do matrix multiplication, like torch.mm(a, b), to reduce memory usage on a single GPU.
Here is the code working on a single GPU:
import torch
a = torch.randn(30000, 30000).cuda(1)
b = torch.randn(30000, 30000).cuda(1)
c = torch.mm(a, b)
# during this process, the maximum memory usage is 10491 MB.
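(For reference, a 30000 × 30000 float32 tensor takes 30000 · 30000 · 4 bytes ≈ 3.4 GB, so holding a, b, and the result c at once accounts for roughly 10 GB, consistent with the 10491 MB reported above.)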
Here is the code working on two GPUs:
import torch
# assuming `a1` and `a2` are parts of a big matrix
a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 30000).cuda(0)
b2 = b1.cuda(1)
c1 = torch.mm(a1, b1)
c2 = torch.mm(a2, b2).to(0)
# at this point, both results `c1` and `c2` are on GPU 0
# the maximum memory usage on GPU 1 is 7059 MB
# the maximum memory usage on GPU 0 is 8777 MB, larger than on GPU 1 because both results end up on it
c = torch.concat([c1, c2], dim=0)
# OOM, because concat allocates a new output tensor instead of working in-place
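(The numbers line up with the tensor sizes: each 15000 × 30000 half is ≈ 1.7 GB and each 30000 × 30000 matrix is ≈ 3.4 GB. GPU 1 holds a2 + b2 + c2 ≈ 6.7 GB, matching the 7059 MB; GPU 0 holds a1 + b1 + c1 plus the copy of c2, ≈ 8.4 GB, matching the 8777 MB; and the concat then needs a fresh ≈ 3.4 GB output buffer for c on top of that, which is what triggers the OOM.)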
Therefore, if the concat operation could be made in-place, it seems this would work as expected? Or should I move c1 and c2 to CPU memory first, cat them there, and then move the concatenated result back to the GPU?
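Here is a minimal sketch of that second idea, continuing from the two-GPU snippet above (c1_cpu and c2_cpu are just illustrative names, and the final move back assumes GPU 0 regains enough room once c1 and c2 are freed):
# move the partial results to CPU, concatenate there, and only then bring the
# full matrix back to a GPU
c1_cpu = c1.cpu()
c2_cpu = c2.cpu()
del c1, c2                              # drop the two ~1.7 GB copies on GPU 0
torch.cuda.empty_cache()                # optional: release the cached blocks
c = torch.cat([c1_cpu, c2_cpu], dim=0)  # full 30000 x 30000 result on the CPU
c = c.cuda(0)                           # ~3.4 GB, fits in the space just freed
The extra host round trip costs time, but the peak GPU usage should stay around the 8777 MB measured above rather than growing by another full matrix.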
I have also tried tensor parallelism provided by PyTorch 2.2:
import torch
import torch.distributed as distributed
import os
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
from visualize_sharding import visualize_sharding
mesh = init_device_mesh("cuda", (2,))
rank = distributed.get_rank()
big_tensor_1 = torch.randn(3, 2)
big_tensor_2 = torch.randn(2, 6)
print("big_tensor_1", big_tensor_1)
my_dtensor_1 = distribute_tensor(big_tensor_1, mesh, [Shard(dim=0)])
my_dtensor_2 = distribute_tensor(big_tensor_2, mesh, [Shard(dim=1)])
# visualize_sharding(my_dtensor_1, header="my_dtensor_1")
c = torch.mm(my_dtensor_1, my_dtensor_2)
print("c: ", c)
But everything runs twice, because the launch command was python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 tmp.py, so two different big_tensor_1 tensors are randomly generated. How can I modify the code so that the tensor is generated once but still processed by two processes?
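One workaround that seems plausible for this (an assumption on my side, not a verified fix): with this SPMD-style launch the script always executes once per process, so instead of trying to run it "once", every rank can be made to generate the same data (e.g. by seeding the RNG identically) and side effects like print can be guarded to a single rank; torchrun is also the recommended replacement for the deprecated torch.distributed.launch. A sketch:
import torch
import torch.distributed as distributed
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor

# Assumption: seeding every rank identically makes torch.randn return the same
# "global" tensor on both processes, so distribute_tensor shards one logical
# matrix instead of two different random ones.
# Launched with e.g.: torchrun --nproc_per_node=2 --nnodes=1 tmp.py
mesh = init_device_mesh("cuda", (2,))
rank = distributed.get_rank()

torch.manual_seed(0)                          # same seed on every rank
big_tensor_1 = torch.randn(3, 2)              # now identical across processes
big_tensor_2 = torch.randn(2, 6)

my_dtensor_1 = distribute_tensor(big_tensor_1, mesh, [Shard(dim=0)])
my_dtensor_2 = distribute_tensor(big_tensor_2, mesh, [Shard(dim=1)])
c = torch.mm(my_dtensor_1, my_dtensor_2)

if rank == 0:                                 # print once, not once per rank
    print("c:", c)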
Everything I tried is listed in the problem details above.
I tried the following approach, which partially solves the first problem in the problem details, but does not completely resolve it.
UPDATE: the code below reduces the maximum memory usage on a single GPU when using multiple GPUs (2 GPUs here):