I'll use a tiny neural network as a simple example to demonstrate fp8 vs fp32 memory usage. Let's make a small model with these layers (a quick code sketch follows the list):
1. Input: 784 features (like a 28×28 MNIST image)
2. Hidden layer 1: 512 neurons
3. Hidden layer 2: 256 neurons
4. Output: 10 neurons (one per digit class)
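Just as a sketch, here's what that architecture could look like in PyTorch (the framework choice is mine; any framework would do, the layer sizes are what matter):
```python
import torch.nn as nn

# The architecture listed above: 784 -> 512 -> 256 -> 10
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
```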
Let's count the parameters (weights plus biases) in each layer and the memory they need:
1. First Layer Weights:
```
784 × 512 = 401,408 weights
+ 512 biases
= 401,920 parameters
```
2. Second Layer Weights:
```
512 × 256 = 131,072 weights
+ 256 biases
= 131,328 parameters
```
3. Output Layer Weights:
```
256 × 10 = 2,560 weights
+ 10 biases
= 2,570 parameters
```
Total Parameters: 535,818
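As a sanity check, counting the parameters of the PyTorch sketch above gives the same number:
```python
# Sum the parameter counts of every layer in the sketch model.
total = sum(p.numel() for p in model.parameters())
print(total)  # 535818 = 401,920 + 131,328 + 2,570
```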
Memory Usage:
```
FP32: 535,818 × 4 bytes = 2,143,272 bytes ≈ 2.14 MB
FP8: 535,818 × 1 byte = 535,818 bytes ≈ 0.54 MB
```
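The same arithmetic in a couple of lines of Python (1 MB here means 10^6 bytes):
```python
params = 535_818  # total parameter count from above

for name, bytes_per_value in [("fp32", 4), ("fp8", 1)]:
    total_bytes = params * bytes_per_value
    print(f"{name}: {total_bytes:,} bytes ≈ {total_bytes / 1e6:.2f} MB")
# fp32: 2,143,272 bytes ≈ 2.14 MB
# fp8:    535,818 bytes ≈ 0.54 MB
```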
Let's demonstrate this with some actual matrix multiplication:
```python
import numpy as np

# One batch of inference: 32 images (batch) × 784 features = 25,088 numbers
x = np.random.rand(32, 784).astype(np.float32)
W1 = np.random.rand(784, 512).astype(np.float32)  # first-layer weights

# First layer multiplication: (32 × 784) @ (784 × 512) → (32 × 512)
h1 = x @ W1
print(h1.shape)  # (32, 512)
```
During computation of this first layer (the output is 32 × 512 = 16,384 values):
1. With fp32:
```
Weights in memory: 401,920 × 4 = 1,607,680 bytes
Input in memory: 25,088 × 4 = 100,352 bytes
Output in memory: 16,384 × 4 = 65,536 bytes
Total: ≈ 1.77 MB
```
2. With fp8:
```
Weights in memory: 401,920 × 1 = 401,920 bytes
Input in memory: 25,088 × 1 = 25,088 bytes
Output in memory: 16,384 × 1 = 16,384 bytes
Total: ≈ 0.44 MB
```
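We can verify these working-set numbers from actual buffer sizes in NumPy; since NumPy has no fp8 dtype, `uint8` stands in below purely to count one byte per value:
```python
import numpy as np

# Allocate the first layer's working set in each precision and count bytes.
for dtype, label in [(np.float32, "fp32"), (np.uint8, "fp8 (1 byte/value)")]:
    w = np.zeros(401_920, dtype=dtype)   # first-layer weights + biases
    x = np.zeros(32 * 784, dtype=dtype)  # batch input
    y = np.zeros(32 * 512, dtype=dtype)  # batch output
    total = w.nbytes + x.nbytes + y.nbytes
    print(f"{label}: {total:,} bytes ≈ {total / 1e6:.2f} MB")
# fp32: 1,773,568 bytes ≈ 1.77 MB
# fp8 (1 byte/value): 443,392 bytes ≈ 0.44 MB
```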
During the actual multiplication, the work proceeds tile by tile:
```
1. Load a tile/block of the weight matrix (let's say 128×128)
fp8: 128×128 = 16,384 bytes
2. Convert this block to fp32: 16,384 × 4 = 65,536 bytes
3. Perform multiplication in fp32
4. Convert result back to fp8
5. Move to next block
```
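To make that loop concrete, here is a toy NumPy sketch. NumPy has no native fp8 type, so a per-tile scale plus int8 codes stands in for the 1-byte storage, and the tiling is over one dimension only for brevity; it illustrates the load → upcast → multiply → accumulate pattern rather than a real fp8 kernel:
```python
import numpy as np

def quantize(block):
    """Toy 8-bit storage: one scale per tile plus int8 codes (1 byte/value)."""
    scale = np.abs(block).max() / 127.0 + 1e-12
    return np.round(block / scale).astype(np.int8), scale

def dequantize(codes, scale):
    """Expand the 1-byte codes back to fp32 for the actual math."""
    return codes.astype(np.float32) * scale

rng = np.random.default_rng(0)
W = rng.standard_normal((784, 512)).astype(np.float32)  # full weight matrix
x = rng.standard_normal((32, 784)).astype(np.float32)   # one input batch

tile = 128
out = np.zeros((32, 512), dtype=np.float32)

# Load one block of weights, convert to fp32, multiply, accumulate, move on.
for start in range(0, W.shape[0], tile):
    codes, scale = quantize(W[start:start + tile])   # stored at 1 byte per value
    w_block = dequantize(codes, scale)               # temporary fp32 copy of one tile
    out += x[:, start:start + tile] @ w_block        # the multiply happens in fp32

print(out.shape)  # (32, 512)
```
Only one fp32 copy of a tile exists at a time; the rest of the weights stay in their 1-byte form.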
This shows that even though we compute in fp32, keeping the model in fp8:
1. Uses 1/4 of the memory for storage
2. Only needs small blocks in fp32 temporarily
3. Lets us process larger batches or models within the same memory budget