1/08/2025

fsdp mixed precision pure vs default

The difference between `mixed_precision: PURE` and `mixed_precision: DEFAULT` in FSDP:


`mixed_precision: DEFAULT` (what you saw in logs):

- Parameters are stored in bfloat16

- Gradients are computed locally in bfloat16, then upcast to float32 for the cross-device reduction (`reduce_dtype`)

- Buffers (like batch norm stats) are in bfloat16

- Shows in the logs as: "param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.bfloat16"
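
For reference, a minimal sketch of how DEFAULT maps onto PyTorch's `MixedPrecision` policy (assuming the config value is forwarded to FSDP directly; the exact wiring depends on your training framework):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Sketch of a DEFAULT-style policy: bf16 params and buffers, fp32 gradient reduction.
default_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters stored and all-gathered in bf16
    reduce_dtype=torch.float32,   # gradients upcast to fp32 for reduce-scatter
    buffer_dtype=torch.bfloat16,  # buffers (e.g. norm statistics) kept in bf16
)
```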


`mixed_precision: PURE`:

- Parameters are stored in bfloat16

- Gradients are both computed and reduced in bfloat16 (the reduce dtype is the key difference)

- Buffers are in bfloat16

- Would show in logs: "param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16"
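
A PURE-style policy would look like this (again just a sketch; only `reduce_dtype` changes), passed to the FSDP wrapper via `mixed_precision=`:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Sketch of a PURE-style policy: everything, including gradient reduction, in bf16.
pure_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,  # the only change from the DEFAULT sketch above
    buffer_dtype=torch.bfloat16,
)

# model = FSDP(model, mixed_precision=pure_policy)
```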


Performance comparison:

1. Memory Usage:

- PURE uses less memory because the gradient reduction happens in bfloat16

- DEFAULT uses more memory because it allocates float32 buffers for the gradient reduction (see the byte calculation after this list)


2. Speed:

- PURE is typically faster because:

  - Less memory bandwidth used during gradient communication

  - Faster gradient reduction operations

  - Particularly beneficial for multi-node training, where interconnect bandwidth is the bottleneck

- However, training might be less stable


3. Training Stability:

- DEFAULT is more numerically stable because the gradient reduction happens in float32, which avoids precision loss when summing contributions across many devices

- PURE might require more careful tuning of learning rate and other hyperparameters
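
The memory and bandwidth points above come down to bytes per gradient element. A quick back-of-the-envelope example (assuming a 7B-parameter model for illustration; actual memory and traffic depend on sharding, bucketing, and communication/compute overlap):

```python
num_params = 7_000_000_000        # assumed 7B-parameter model

grad_bytes_bf16 = num_params * 2  # bf16: 2 bytes per gradient element
grad_bytes_fp32 = num_params * 4  # fp32: 4 bytes per gradient element

print(f"bf16 gradient reduction volume: {grad_bytes_bf16 / 1e9:.0f} GB")  # ~14 GB
print(f"fp32 gradient reduction volume: {grad_bytes_fp32 / 1e9:.0f} GB")  # ~28 GB
```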


Given the logged throughput of around 191 tokens/sec/device, you might get better performance with PURE mode, possibly 5-15% faster due to reduced communication overhead. However, if you experience training instability (very high loss values or NaNs), you should switch back to DEFAULT.


Recommendation:

1. Start with PURE for better performance

2. Monitor training metrics closely (loss spikes, NaN/Inf values; a minimal check is sketched below)

3. If you see instability, fall back to DEFAULT
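
A minimal sketch of what that monitoring could look like in the training loop (the helper name and the loss ceiling are made up for illustration):

```python
import math

def training_looks_healthy(loss_value: float, max_loss: float = 1e4) -> bool:
    """Return True if the loss is finite and below an (arbitrary) ceiling."""
    return math.isfinite(loss_value) and loss_value < max_loss

# In the training loop:
# if not training_looks_healthy(loss.item()):
#     # fall back to mixed_precision: DEFAULT and resume from the last checkpoint
#     ...
```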

