`mixed_precision: PURE` and `mixed_precision: DEFAULT` in FSDP:
`mixed_precision: DEFAULT` (the mode you saw in your logs):
- Parameters are cast to bfloat16 for computation
- Gradients are computed in bfloat16 but upcast to float32 for the cross-device reduction
- Buffers (like batch norm stats) are in bfloat16
- Results in log: "param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.bfloat16"
`mixed_precision: PURE`:
- Parameters are cast to bfloat16 for computation
- Gradients are computed and reduced in bfloat16 (this is the key difference)
- Buffers are in bfloat16
- Would show in logs: "param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16"
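As a concrete illustration, here is a minimal sketch of how these two modes map onto PyTorch's own FSDP `MixedPrecision` policy (the commented-out `model` is a placeholder for your own module, and distributed setup is assumed to be done elsewhere):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

# DEFAULT-style policy: compute in bfloat16, reduce gradients in float32
default_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# PURE-style policy: everything, including gradient reduction, in bfloat16
pure_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# model = ...  # your module, already on the correct device
# sharded = FSDP(model, mixed_precision=pure_policy)
```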
Performance comparison:
1. Memory Usage:
- PURE keeps gradients in bfloat16 end to end, so gradient buffers are half the size
- DEFAULT upcasts gradients to float32 for the reduction, roughly doubling the memory and bandwidth they consume (a back-of-envelope calculation follows this list)
2. Speed:
- PURE is typically faster because:
  - Gradient communication moves half as many bytes (2 bytes per value instead of 4)
  - The reduction operations themselves run faster on the smaller dtype
  - The benefit grows with device count, since gradient communication dominates at scale
- However, training might be less stable
3. Training Stability:
- DEFAULT is more numerically stable because the gradient reduction (a large sum across devices) happens in float32; bfloat16 carries far fewer mantissa bits, so accumulating many gradient shards in it loses precision
- PURE might require more careful tuning of the learning rate and other hyperparameters
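To make the memory and bandwidth point concrete, here is a back-of-envelope calculation; the 7B parameter count is purely an assumption for the sake of the example:

```python
# Gradient-communication volume per step, assuming a 7B-parameter model
# (the parameter count here is illustrative, not from the original logs).
n_params = 7e9

bytes_default = n_params * 4  # DEFAULT: gradients reduced in float32 (4 bytes each)
bytes_pure = n_params * 2     # PURE: gradients reduced in bfloat16 (2 bytes each)

print(f"float32 reduction:  {bytes_default / 1e9:.0f} GB")  # ~28 GB
print(f"bfloat16 reduction: {bytes_pure / 1e9:.0f} GB")     # ~14 GB
```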
Given the throughput of around 191 tokens/sec/device in your logs, PURE mode could plausibly run 5-15% faster thanks to the reduced communication overhead. However, if you experience training instability (very high loss values or NaNs), you should switch back to DEFAULT.
Recommendation:
1. Start with PURE for better performance
2. Monitor training metrics closely (a minimal NaN/Inf check is sketched below)
3. If you see instability, fall back to DEFAULT
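For step 2, here is a minimal sketch of an instability check you could drop into a training loop; the loop itself, `model`, `loss`, and `step` are placeholders for your own training code:

```python
import math

import torch


def check_stability(model: torch.nn.Module, loss: torch.Tensor, step: int) -> bool:
    """Return False if the loss or any gradient has gone non-finite."""
    if not math.isfinite(loss.item()):
        print(f"step {step}: non-finite loss {loss.item()}; consider falling back to DEFAULT")
        return False
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"step {step}: non-finite gradient in {name}")
            return False
    return True


# Inside the training loop, after loss.backward() and before optimizer.step():
# if not check_stability(model, loss, step):
#     ...  # e.g. skip the step, lower the learning rate, or switch back to DEFAULT
```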