Toy model configuration:
- Input Sharding:
  - The input sequence (shape [4 x 12 x 10]) is initially split along the sequence-length dimension across 3 GPUs.
  - Each GPU receives a [4 x 4 x 10] shard of the input.
- All-Gather Operation:
  - An all-gather operation reconstructs the full input on each GPU (see the sketch below).
  - After this, each GPU has the full [4 x 12 x 10] input.
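To make these first two steps concrete, here is a minimal single-process sketch that simulates the sequence-dimension sharding and the all-gather with `torch.chunk` and `torch.cat`; a real implementation would use `torch.distributed.all_gather` across ranks:

```python
import torch

# Full input: [batch=4, seq=12, hidden=10].
x = torch.randn(4, 12, 10)

# Shard along the sequence dimension (dim=1) across 3 simulated "GPUs":
# torch.chunk yields three [4 x 4 x 10] shards.
shards = torch.chunk(x, 3, dim=1)
assert all(s.shape == (4, 4, 10) for s in shards)

# All-gather: every rank ends up with the concatenation of all shards,
# i.e. the full [4 x 12 x 10] input.
gathered = torch.cat(shards, dim=1)
assert torch.equal(gathered, x)
```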
- First Layer, `in_proj` (ColwiseParallel):
  - The weight matrix [10 x 32] is split column-wise across GPUs: [10 x 11], [10 x 11], [10 x 10].
  - Each GPU multiplies the full input [4 x 12 x 10] by its portion of the weight matrix (see the sketch below).
  - The outputs on the GPUs are [4 x 12 x 11], [4 x 12 x 11], and [4 x 12 x 10] respectively.
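A sketch of the column-wise split, using the walkthrough's math convention of a [in x out] weight (note that PyTorch's `nn.Linear` stores the transpose). `torch.chunk` over the column dimension reproduces exactly the uneven 11/11/10 split:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 12, 10)   # full input, present on every GPU after the all-gather
w_in = torch.randn(10, 32)   # in_proj weight, [in x out] convention

# Column-wise split: three chunks of shape [10 x 11], [10 x 11], [10 x 10].
w_cols = torch.chunk(w_in, 3, dim=1)

# Each simulated "GPU" multiplies the full input by its column shard.
outs = [x @ w for w in w_cols]
print([tuple(o.shape) for o in outs])  # [(4, 12, 11), (4, 12, 11), (4, 12, 10)]

# Concatenating the shards reproduces the unpartitioned first-layer output.
assert torch.allclose(torch.cat(outs, dim=-1), x @ w_in)
```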
- ReLU Activation:
  - Applied element-wise to the first-layer output on each GPU.
  - Shapes remain [4 x 12 x 11], [4 x 12 x 11], and [4 x 12 x 10] on the respective GPUs.
- Second Layer, `out_proj` (RowwiseParallel):
  - The weight matrix [32 x 5] is split row-wise across GPUs: [11 x 5], [11 x 5], [10 x 5].
  - Each GPU multiplies its activation ([4 x 12 x 11], [4 x 12 x 11], or [4 x 12 x 10]) by its portion of the weight matrix.
  - The output on each GPU is [4 x 12 x 5], a partial sum for the full sequence (see the sketch below).
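Continuing the same single-process sketch through the ReLU and the row-wise second layer: because ReLU is element-wise, each simulated GPU can apply it to its own column shard; the three [4 x 12 x 5] partials then sum to the unpartitioned two-layer result:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 12, 10)
w_in = torch.randn(10, 32)   # in_proj weight
w_out = torch.randn(32, 5)   # out_proj weight

w_cols = torch.chunk(w_in, 3, dim=1)   # [10 x 11], [10 x 11], [10 x 10]
w_rows = torch.chunk(w_out, 3, dim=0)  # [11 x 5],  [11 x 5],  [10 x 5]

# Each "GPU": ReLU of its first-layer shard, then matmul with its row shard.
# ReLU commutes with the column split because it acts element-wise.
partials = [torch.relu(x @ wc) @ wr for wc, wr in zip(w_cols, w_rows)]
print([tuple(p.shape) for p in partials])  # [(4, 12, 5)] * 3

# Summing the partials matches the unpartitioned forward pass.
full = torch.relu(x @ w_in) @ w_out
assert torch.allclose(sum(partials), full)
```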
- Reduce-Scatter Operation:
  - A reduce-scatter operation sums the partial results and distributes the sum across GPUs (see the sketch below).
  - Each GPU ends up with a [4 x 4 x 5] portion of the final output, sharded along the sequence dimension.
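And a single-process stand-in for the final step; real code would use `torch.distributed.reduce_scatter_tensor`, which fuses the sum and the split:

```python
import torch

# Three per-GPU partial outputs, each [4 x 12 x 5] (as produced above).
partials = [torch.randn(4, 12, 5) for _ in range(3)]

# Reduce-scatter = element-wise sum of all partials, with each rank keeping
# only its slice of the summed result along the sequence dimension (dim=1).
reduced = partials[0] + partials[1] + partials[2]
out_shards = torch.chunk(reduced, 3, dim=1)
assert all(s.shape == (4, 4, 5) for s in out_shards)  # one [4 x 4 x 5] shard per GPU
```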
Key Corrections and Clarifications:
- There are indeed two collective operations: an all-gather at the beginning and a reduce-scatter at the end.
- The GPUs do not end up with identically shaped first-layer outputs, because the 32 output columns split unevenly (11/11/10) across the 3 GPUs.
- The sequence dimension (12 in this example) is not sharded during the middle layers but is reconstructed and then re-sharded at the end.
This corrected diagram and explanation more accurately represent the sequence-parallelism process described in the original comment: the input is gathered, processed in parallel, and the result is reduced and scattered back, allowing the entire sequence to be processed efficiently across GPUs. A sketch of the same setup in PyTorch's tensor-parallel API follows below.
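For completeness, here is a hedged sketch of how this toy configuration might be expressed with PyTorch's tensor-parallel API (assuming PyTorch 2.4+ and 3 GPUs launched via torchrun; `ToyModel` and the layer names are illustrative, not from the original). `Shard(1)` on the input and output layouts is what inserts the all-gather at the start and the reduce-scatter at the end:

```python
# Launch with: torchrun --nproc_per_node=3 toy_tp.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyModel(nn.Module):
    """Illustrative module matching the shapes in the walkthrough."""
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

mesh = init_device_mesh("cuda", (3,))
model = parallelize_module(
    ToyModel().to("cuda"),
    mesh,
    {
        # Input arrives sharded on the sequence dim (1): triggers the all-gather.
        # Note the 32-wide hidden dim shards unevenly (11/11/10), as above.
        "in_proj": ColwiseParallel(input_layouts=Shard(1)),
        # Reshard the output back onto the sequence dim: triggers the reduce-scatter.
        "out_proj": RowwiseParallel(output_layouts=Shard(1)),
    },
)

x_local = torch.randn(4, 4, 10, device="cuda")  # this rank's [4 x 4 x 10] shard
y_local = model(x_local)                        # this rank's [4 x 4 x 5] shard
```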