Toy model
class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))
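As a quick sanity check of the unparallelized shapes (assuming torch and the ToyModel class above are in scope): nn.Linear acts on the last dimension, so a [4 x 12 x 10] input, the example used in the walkthrough below, flows through as 10 -> 32 -> 5 features.

# Single-device shape check on CPU; the [4, 12, 10] input matches the
# walkthrough below (batch 4, sequence 12, features 10).
model = ToyModel()
x = torch.rand(4, 12, 10)
hidden = model.relu(model.in_proj(x))   # [4, 12, 32]
out = model.out_proj(hidden)            # [4, 12, 5]
print(hidden.shape, out.shape)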
Configuration
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)
- Input Sharding:
- The input sequence (shape [4 x 12 x 10]) is initially split along the sequence length dimension across 3 GPUs.
- Each GPU receives a [4 x 4 x 10] shard of the input.
- All-Gather Operation:
- An all-gather operation is performed to reconstruct the full input on each GPU.
- After this, each GPU has the full [4 x 12 x 10] input.
- First Layer - in_proj (ColwiseParallel):
- The weight matrix [10 x 32] is split column-wise across GPUs: [10 x 11], [10 x 11], [10 x 10].
- Each GPU processes the full input [4 x 12 x 10] with its portion of the weight matrix.
- The output on each GPU is [4 x 12 x 11], [4 x 12 x 11], and [4 x 12 x 10] respectively.
- ReLU Activation:
- Applied element-wise to the output of the first layer on each GPU.
- Shapes remain [4 x 12 x 11], [4 x 12 x 11], and [4 x 12 x 10] on the respective GPUs.
- Second Layer - out_proj (RowwiseParallel):
- The weight matrix [32 x 5] is split row-wise across GPUs: [11 x 5], [11 x 5], [10 x 5].
- Each GPU processes its input ([4 x 12 x 11], [4 x 12 x 11], [4 x 12 x 10]) with its portion of the weight matrix.
- The output on each GPU is [4 x 12 x 5], representing partial sums for the full sequence.
- Reduce-Scatter Operation:
- A reduce-scatter operation is performed to sum the partial results and distribute them across GPUs.
- This results in each GPU having a portion of the final output, sharded along the sequence dimension (a single-device simulation of these shapes follows this list).
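The shape bookkeeping above can be reproduced on one device with plain tensor ops. The sketch below is only a shape-tracing simulation under the walkthrough's assumptions (3 GPUs, biases omitted, collectives replaced by concatenation and summation, torch.chunk standing in for the uneven sharding); it is not what DTensor actually executes, but it reproduces the shapes and confirms the partial sums add up to the unsharded result.

import torch

world_size = 3

# Input sharding: each "GPU" starts with a [4 x 4 x 10] slice of the sequence.
x_shards = list(torch.rand(4, 12, 10).chunk(world_size, dim=1))

# All-gather: every rank reconstructs the full [4 x 12 x 10] input.
x_full = torch.cat(x_shards, dim=1)

# ColwiseParallel: split the [10 x 32] weight into [10 x 11], [10 x 11], [10 x 10].
w1 = torch.rand(10, 32)
h_shards = [torch.relu(x_full @ w) for w in w1.chunk(world_size, dim=1)]
# -> [4 x 12 x 11], [4 x 12 x 11], [4 x 12 x 10]

# RowwiseParallel: split the [32 x 5] weight into [11 x 5], [11 x 5], [10 x 5].
w2 = torch.rand(32, 5)
partials = [h @ w for h, w in zip(h_shards, w2.chunk(world_size, dim=0))]
# -> each partial is [4 x 12 x 5]

# Reduce-scatter: sum the partials, then re-shard along the sequence dimension.
y_full = sum(partials)                      # [4 x 12 x 5]
y_shards = y_full.chunk(world_size, dim=1)  # [4 x 4 x 5] per "GPU"

assert torch.allclose(y_full, torch.relu(x_full @ w1) @ w2, atol=1e-5)
print([tuple(t.shape) for t in y_shards])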
Key points:
- There are two collective operations: an all-gather at the beginning and a reduce-scatter at the end.
- The GPUs do not end up with equally sized tensors after the first layer, because the weight matrix splits unevenly across 3 GPUs.
- The sequence dimension (12 in this example) is gathered at the start, kept whole through the middle of the model, and re-sharded at the end.
This is the essence of sequence parallelism: the input is gathered, processed in parallel, and the output is scattered, allowing the full sequence to be processed efficiently across GPUs.
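To confirm on real hardware which dimension each weight actually ends up sharded on, the parameters of the parallelized module (which are DTensors after parallelize_module) can be inspected. The snippet below is a hypothetical addition, meant to sit inside main() right after the parallelize_module call in the full script that follows and to run under torchrun; exact placements and local shapes may vary with the PyTorch version. Note that PyTorch stores nn.Linear weights as [out_features, in_features], so the [10 x 32] and [32 x 5] matrices of the walkthrough appear as [32, 10] and [5, 32] here.

# Hypothetical debugging snippet (not part of the original script): each parameter
# of sp_model is a DTensor, so its placements and per-rank local shard can be printed.
# With 3 GPUs, ColwiseParallel should shard in_proj.weight on dim 0 of its [32, 10]
# storage (local shards [11, 10], [11, 10], [10, 10]), and RowwiseParallel should
# shard out_proj.weight on dim 1 of its [5, 32] storage.
for name, param in sp_model.named_parameters():
    print(f"{name}: placements={param.placements}, local shape={tuple(param.to_local().shape)}")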
import os
import sys

import torch
import torch.nn as nn
import torch.profiler

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh
"""
This is the script to test Sequence Parallel(SP) on a toy model in a
Megetron-LM SPMD style. We show an E2E working flow from forward,
backward and optimization.
We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
in between to show an example of sequence parallel, which was proposed in paper:
https://arxiv.org/pdf/2205.05198.pdf.
Like tensor parallel, we parallelize the first linear layer by column
and also parallelize the second linear layer by row. But the input in each rank
now is different so that we need one all-gather for input and one reduce-scatter
in the end of the second linear layer.
"""
class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))
def main():
    logger = get_logger()

    # Create a device mesh based on the given world_size.
    device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
    )

    _rank = device_mesh.get_rank()

    print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")
    rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

    # Create the model and move it to GPU. init_device_mesh has already assigned GPU ids.
    model = ToyModel().to("cuda")

    # Custom parallelization plan for the model.
    sp_model = parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
        },
    )

    # Create an optimizer for the parallelized module.
    lr = 0.25
    optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)

    # Perform a number of forward/backward/optimizer iterations
    # for the sharded module.
    num_iters = 10
    rank_log(_rank, logger, "Sequence Parallel training starting...")

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./log/tensorboard/rank_{_rank}'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for i in range(num_iters):
            # For SP, the input can be different across all ranks.
            inp = torch.rand(20, 10, device="cuda")
            output = sp_model(inp)
            output.sum().backward()
            optimizer.step()
            rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
            prof.step()

    rank_log(_rank, logger, "Sequence Parallel training completed!")

    # Print profiler results.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


if __name__ == "__main__":
    main()
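The script is an SPMD program: every rank runs the same file, and WORLD_SIZE is read from the environment, which torchrun sets. A typical launch (the file name below is an assumption; use whatever name you saved the script under, and set --nproc_per_node to your GPU count, at least 2):

torchrun --nnodes=1 --nproc_per_node=4 sequence_parallel_example.py

Each rank writes its profiler trace under ./log/tensorboard/rank_<rank>. With the TensorBoard profiler plugin installed (pip install torch_tb_profiler), the traces can be browsed with tensorboard --logdir=./log/tensorboard, where the all-gather and reduce-scatter collectives should show up as communication kernels.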
Thank you!