
1/19/2025

FP32 vs FP8 with a tiny NN model

I'll create a simple example of a tiny neural network to demonstrate fp8 vs fp32 memory usage. Let's make a small model with these layers (sketched in code right after the list):


1. Input: 784 features (like MNIST image 28x28)

2. Hidden layer 1: 512 neurons

3. Hidden layer 2: 256 neurons

4. Output: 10 neurons (for 10 digit classes)
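
Here's a minimal sketch of that model, assuming PyTorch is available; the parameter count it reports matches the hand calculation below.

```python
import torch.nn as nn

# Tiny MLP with the layer sizes listed above
model = nn.Sequential(
    nn.Linear(784, 512),  # input → hidden 1
    nn.ReLU(),
    nn.Linear(512, 256),  # hidden 1 → hidden 2
    nn.ReLU(),
    nn.Linear(256, 10),   # hidden 2 → output
)

print(sum(p.numel() for p in model.parameters()))  # 535818
```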


Let's calculate the memory needed for weights:


1. First Layer Weights:

```

784 × 512 = 401,408 weights

+ 512 biases

= 401,920 parameters

```


2. Second Layer Weights:

```

512 × 256 = 131,072 weights

+ 256 biases

= 131,328 parameters

```


3. Output Layer Weights:

```

256 × 10 = 2,560 weights

+ 10 biases

= 2,570 parameters

```


Total Parameters: 535,818


Memory Usage:

```

FP32: 535,818 × 4 bytes = 2,143,272 bytes ≈ 2.14 MB

FP8:  535,818 × 1 byte  =   535,818 bytes ≈ 0.54 MB

```
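
These counts and byte sizes can be verified with a few lines of plain Python:

```python
# Per-layer parameter counts for the 784-512-256-10 model
layers = [(784, 512), (512, 256), (256, 10)]
params = sum(i * o + o for i, o in layers)  # weights + biases per layer
print(params)      # 535818

# 4 bytes per parameter in fp32, 1 byte in fp8
print(params * 4)  # 2143272 bytes (about 2.14 MB)
print(params * 1)  # 535818 bytes (about 0.54 MB)
```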


Let's demonstrate this with some actual matrix multiplication:


```python
import numpy as np

# Example of one batch of inference:
# 32 images (batch) × 784 features = 25,088 input numbers
batch = np.random.rand(32, 784).astype(np.float32)
weights = np.random.rand(784, 512).astype(np.float32)

# First layer multiplication: (32 × 784) × (784 × 512) → (32 × 512)
hidden = batch @ weights
print(hidden.shape)  # (32, 512)
```


During computation:

1. With fp32:

```

Weights in memory: 401,920 × 4 = 1,607,680 bytes

Input in memory: 25,088 × 4 = 100,352 bytes

Output in memory (32 × 512 values): 16,384 × 4 = 65,536 bytes

Total: ≈ 1.77 MB

```


2. With fp8:

```

Weights in memory: 401,920 × 1 = 401,920 bytes

Input in memory: 25,088 × 1 = 25,088 bytes

Output in memory (32 × 512 values): 16,384 × 1 = 16,384 bytes

Total: ≈ 0.44 MB

```
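
The same numbers fall out of numpy's .nbytes attribute. Note that numpy has no fp8 dtype, so uint8 is used below purely as a one-byte stand-in for storage size:

```python
import numpy as np

# fp32 versions of the first layer's tensors
w32 = np.zeros((784, 512), dtype=np.float32)
b32 = np.zeros(512, dtype=np.float32)
x32 = np.zeros((32, 784), dtype=np.float32)
y32 = np.zeros((32, 512), dtype=np.float32)
print(w32.nbytes + b32.nbytes, x32.nbytes, y32.nbytes)  # 1607680 100352 65536

# one-byte stand-ins (uint8 only models the storage size of fp8)
w8 = np.zeros((784, 512), dtype=np.uint8)
b8 = np.zeros(512, dtype=np.uint8)
x8 = np.zeros((32, 784), dtype=np.uint8)
y8 = np.zeros((32, 512), dtype=np.uint8)
print(w8.nbytes + b8.nbytes, x8.nbytes, y8.nbytes)      # 401920 25088 16384
```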


During actual computation:

```

1. Load a tile/block of the weight matrix (let's say 128×128)

   fp8: 128×128 = 16,384 bytes

2. Convert this block to fp32: 16,384 × 4 = 65,536 bytes

3. Perform multiplication in fp32

4. Convert result back to fp8

5. Move to next block

```
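
Here's a rough sketch of that tiled flow in numpy. Since numpy has no fp8 dtype, uint8 values plus a made-up scale factor stand in for the fp8 weights; the point is only that one small block at a time is expanded to fp32:

```python
import numpy as np

TILE = 128
x = np.random.rand(32, 784).astype(np.float32)                 # fp32 activations
w_fp8 = np.random.randint(0, 256, (784, 512), dtype=np.uint8)  # stand-in for fp8 weights
scale = 0.01                                                   # hypothetical dequantization scale

out = np.zeros((32, 512), dtype=np.float32)
for k in range(0, 784, TILE):
    for n in range(0, 512, TILE):
        # Steps 1-2: load one 128×128 block and convert it to fp32 (~64 KB live at once)
        w_block = w_fp8[k:k + TILE, n:n + TILE].astype(np.float32) * scale
        # Step 3: multiply and accumulate in fp32
        out[:, n:n + TILE] += x[:, k:k + TILE] @ w_block
# Steps 4-5: the result could be scaled back down to 1-byte storage before the next layer
```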


This shows that even though we compute in fp32, keeping the model in fp8:

1. Uses 1/4 of the memory for storage

2. Only needs small blocks in fp32 temporarily

3. Can fit larger batches or models in the same amount of memory

1/07/2025

FP32, TF32, FP16, BFLOAT16, FP8

 


A floating-point number consists of three parts:

1. Sign bit (determines whether the number is positive or negative)

2. Exponent (controls how far to move the binary point)

3. Mantissa/Fraction (the significant digits of the number)


Basic Formula:

```

Number = (-1)^sign × (1 + mantissa) × 2^(exponent - bias)

```


Let's break down the number 42.5 into FP32 format:

1. First, convert 42.5 to binary:

   - 42 = 101010 (in binary)

   - 0.5 = 0.1 (in binary)

   - So 42.5 = 101010.1 (binary)


2. Normalize the binary (move the binary point until a single 1 is left of it):

   - 101010.1 = 1.010101 × 2^5

   - Mantissa becomes: 010101

   - Exponent becomes: 5


3. For FP32:

   - Sign bit: 0 (positive number)

   - Exponent: 5 + 127 (bias) = 132 = 10000100

   - Mantissa: 01010100000000000000000
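
This bit pattern can be double-checked with Python's struct module:

```python
import struct

# Reinterpret the 4 bytes of 42.5 (IEEE-754 single precision) as an integer
bits = format(struct.unpack('>I', struct.pack('>f', 42.5))[0], '032b')
print(bits[0], bits[1:9], bits[9:])  # 0 10000100 01010100000000000000000
```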


Example in different formats:


1. FP32 (32-bit):

```

Sign    Exponent     Mantissa

0       10000100    01010100000000000000000

```


2. FP16 (16-bit):

```

Sign    Exponent  Mantissa

0       10100     0101010000

```


3. FP8 (8-bit, E4M3 layout: 4 exponent bits with bias 7 and 3 mantissa bits; 42.5 is no longer exactly representable, so the mantissa is truncated):

```

Sign    Exponent  Mantissa

0       1100      010

```
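
The FP16 pattern can be checked the same way with numpy (numpy has no native FP8 dtype, so only FP16 is verified here):

```python
import numpy as np

# View the 16 bits of 42.5 stored as IEEE-754 half precision
bits = format(int(np.array(42.5, dtype=np.float16).view(np.uint16)), '016b')
print(bits[0], bits[1:6], bits[6:])  # 0 10100 0101010000
```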


Real-world example:

```python

# Breaking down 42.5 in FP32

sign = 0  # positive

exponent = 5 + 127  # actual exponent + bias

mantissa = 0.328125  # binary fraction 0.010101 converted to decimal


# Calculation

value = (-1)**sign * (1 + mantissa) * (2**(exponent - 127))

# = 1 * (1 + 0.328125) * (2**5)

# = 1.328125 * 32

# = 42.5

```


The tradeoffs:

- More exponent bits = larger range of numbers (very big/small)

- More mantissa bits = more precision (more significant digits)

- FP8 sacrifices both for memory efficiency

- TF32 keeps FP32's 8 exponent bits (same range) but trims the mantissa to 10 bits

- BFLOAT16 keeps FP32's 8 exponent bits (same range) but reduces the mantissa to 7 bits (less precision)


This is why different formats are used for different parts of ML models:

- Weights might use FP16/BF16 for a good balance of range and precision

- Activations might use FP8 for efficiency

- Final results might use FP32 for accuracy