1/19/2025

FP32 vs FP8 with a tiny NN model.

Let's use a tiny neural network to demonstrate fp8 vs fp32 memory usage. The model has these layers (sketched in code right after the list):


1. Input: 784 features (like a 28x28 MNIST image)

2. Hidden layer 1: 512 neurons

3. Hidden layer 2: 256 neurons

4. Output: 10 neurons (for 10 digit classes)
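
A minimal sketch of this architecture, assuming PyTorch is available; the ReLU activations are just an illustrative choice and don't affect the memory math below:

```python
# Tiny MLP matching the layer sizes above (784 -> 512 -> 256 -> 10).
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 512),  # hidden layer 1
    nn.ReLU(),
    nn.Linear(512, 256),  # hidden layer 2
    nn.ReLU(),
    nn.Linear(256, 10),   # output layer (10 digit classes)
)
```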


Let's calculate the memory needed for the weights and biases:


1. First Layer Weights:

```

784 × 512 = 401,408 weights

+ 512 biases

= 401,920 parameters

```


2. Second Layer Weights:

```

512 × 256 = 131,072 weights

+ 256 biases

= 131,328 parameters

```


3. Output Layer Weights:

```

256 × 10 = 2,560 weights

+ 10 biases

= 2,570 parameters

```


Total Parameters: 535,818
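
As a quick sanity check, the per-layer counts can be reproduced in a few lines of Python:

```python
# Per-layer parameter counts (weights + biases) for the 784-512-256-10 network.
layers = [(784, 512), (512, 256), (256, 10)]
total = 0
for n_in, n_out in layers:
    params = n_in * n_out + n_out  # weight matrix + bias vector
    total += params
    print(f"{n_in:>3} -> {n_out:<3}: {params:,} parameters")
print(f"Total: {total:,} parameters")  # Total: 535,818 parameters
```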


Memory Usage:

```

FP32: 535,818 × 4 bytes = 2,143,272 bytes ≈ 2.14 MB

FP8:  535,818 × 1 byte  =   535,818 bytes ≈ 0.54 MB

```
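
The same arithmetic in code (using 1 MB = 10^6 bytes, as above):

```python
# Storage needed for all 535,818 parameters at 4 bytes (fp32) vs 1 byte (fp8).
total_params = 535_818
for name, bytes_per_param in [("FP32", 4), ("FP8", 1)]:
    size = total_params * bytes_per_param
    print(f"{name}: {size:>9,} bytes ≈ {size / 1e6:.2f} MB")
# FP32: 2,143,272 bytes ≈ 2.14 MB
# FP8:    535,818 bytes ≈ 0.54 MB
```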


Let's demonstrate this with some actual matrix multiplication:


```python
# Example of one batch of inference (using NumPy for illustration)
import numpy as np

batch = np.random.rand(32, 784).astype(np.float32)  # 32 images (batch) x 784 features = 25,088 numbers
w1 = np.random.rand(784, 512).astype(np.float32)    # first layer weight matrix

# First layer multiplication: (32 x 784) @ (784 x 512) -> (32 x 512)
hidden = batch @ w1
print(hidden.shape)  # (32, 512) -> 16,384 output values
```


During computation of the first layer (batch of 32):

1. With fp32:

```

Weights in memory: 401,920 × 4 = 1,607,680 bytes

Input in memory: 25,088 × 4 = 100,352 bytes

Output in memory: 16,384 × 4 = 65,536 bytes

Total: ≈ 1.77 MB

```


2. With fp8:

```

Weights in memory: 401,920 × 1 = 401,920 bytes

Input in memory: 25,088 × 1 = 25,088 bytes

Output in memory: 16,384 × 1 = 16,384 bytes

Total: ≈ 0.44 MB

```
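
A small sketch that reproduces these working-set numbers for the first layer (sizes only, no actual tensors):

```python
# Working-set sizes for one forward pass through the first layer (batch of 32).
batch, n_in, n_out = 32, 784, 512
n_weights = n_in * n_out + n_out   # 401,920 parameters
n_inputs = batch * n_in            # 25,088 input values
n_outputs = batch * n_out          # 16,384 output values

for name, bytes_per_value in [("fp32", 4), ("fp8", 1)]:
    total = (n_weights + n_inputs + n_outputs) * bytes_per_value
    print(f"{name}: {total:,} bytes ≈ {total / 1e6:.2f} MB")
# fp32: 1,773,568 bytes ≈ 1.77 MB
# fp8:    443,392 bytes ≈ 0.44 MB
```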


During actual computation:

```

1. Load a tile/block of the weight matrix (let's say 128×128)

   fp8: 128 × 128 values × 1 byte = 16,384 bytes

2. Convert this block to fp32: 16,384 values × 4 bytes = 65,536 bytes

3. Perform multiplication in fp32

4. Convert result back to fp8

5. Move to next block

```
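
Here is a toy NumPy sketch of that block-wise pattern. NumPy has no fp8 dtype, so the quantize/dequantize helpers below use int8 with a single scale factor as a 1-byte stand-in; a real fp8 pipeline would use a proper fp8 format with per-block scales, and would typically re-quantize the output as in step 4.

```python
import numpy as np

def quantize_1byte(x, scale):
    # Toy stand-in for fp8: scale and store each value in 1 byte (int8).
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)

def dequantize_fp32(q, scale):
    # Expand the 1-byte values back to fp32 for computation.
    return q.astype(np.float32) * scale

TILE = 128
x = np.random.randn(32, 784).astype(np.float32)    # activations (kept in fp32 here)
w = np.random.randn(784, 512).astype(np.float32)   # "true" first-layer weights
scale = np.abs(w).max() / 127
w_q = quantize_1byte(w, scale)                      # stored weights: 1 byte per value

out = np.zeros((32, 512), dtype=np.float32)
for i in range(0, 784, TILE):
    for j in range(0, 512, TILE):
        # 1. Load one 128x128 block of the stored weights (16,384 bytes).
        block_q = w_q[i:i + TILE, j:j + TILE]
        # 2. Convert the block to fp32 (65,536 bytes, alive only for this step).
        block_f32 = dequantize_fp32(block_q, scale)
        # 3. Multiply in fp32 and accumulate; then move to the next block.
        out[:, j:j + TILE] += x[:, i:i + TILE] @ block_f32
```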


This shows that even though we compute in fp32, keeping the model weights in fp8:

1. Uses 1/4 the memory for storage

2. Only needs small blocks in fp32 temporarily

3. Can process larger batches or models within the same memory budget
