Concepts Added:
- 2D grid/block configuration
- Matrix multiplication basics
- Each thread computes one output element
Key Pattern:
// Each thread computes one output element C[row][col]
float sum = 0.0f;
for (int k = 0; k < K; k++) {
    sum += A[row][k] * B[k][col];
}
C[row][col] = sum;
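For context, here is a minimal sketch of the full kernel this pattern sits inside, assuming row-major float matrices; the kernel name and launch configuration are illustrative, not from the original notes:
```cpp
// Naive GEMM: one thread per output element, 2D grid/block configuration.
__global__ void naive_matmul(const float* A, const float* B, float* C,
                             int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;  // row of C owned by this thread
    int col = blockIdx.x * blockDim.x + threadIdx.x;  // column of C owned by this thread
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];   // dot(row of A, column of B)
        }
        C[row * N + col] = sum;
    }
}

// Example launch: 16x16 threads per block, enough blocks to cover all of C
// dim3 block(16, 16);
// dim3 grid((N + 15) / 16, (M + 15) / 16);
// naive_matmul<<<grid, block>>>(A, B, C, M, N, K);
```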
1. About Output Types (D):
No, the output D is not limited to fp32/int32. Looking at the table, D can be:
- fp32
- fp16
- bf16
- fp8
- bf8
- int8
2. Input/Output Patterns:
When A is fp16, you have two options:
```
Option 1:
A: fp16 → B: fp16 → C: fp16 → D: fp16 → Compute: fp32
Option 2:
A: fp16 → B: fp16 → C: fp16 → D: fp32 → Compute: fp32
```
The compute/scale precision is always higher (fp32 or int32) to maintain accuracy during accumulation, even when inputs and outputs are lower precision. For example, in fp16 the sum 2048 + 1 rounds back to 2048, so small per-element contributions of a long dot product would be silently dropped without a wider accumulator.
3. Key Patterns in the Table:
- Inputs A and B must always match in type
- C typically matches A and B, except with fp8/bf8 inputs
- When using fp8/bf8 inputs, C and D can be higher precision (fp32, fp16, or bf16)
- The compute precision is always fp32 for floating point types
- For integer operations (int8), the compute precision is int32
4. Why Different Combinations?
- Performance: Lower precision (fp16, fp8) = faster computation + less memory
- Accuracy: Higher precision (fp32) = better accuracy but slower
- Memory Usage: fp16/fp8 use less memory than fp32
- Mixed Precision: Use lower precision for inputs but higher precision for output to balance speed and accuracy
Example Use Cases:
```
High Accuracy Needs:
A(fp32) → B(fp32) → C(fp32) → D(fp32) → Compute(fp32)
Balanced Performance:
A(fp16) → B(fp16) → C(fp16) → D(fp32) → Compute(fp32)
Maximum Performance:
A(fp8) → B(fp8) → C(fp8) → D(fp8) → Compute(fp32)
```
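As a concrete illustration of the "Balanced Performance" pattern above (fp16 storage, fp32 compute), here is a minimal sketch of a kernel that reads fp16 inputs but accumulates and writes in fp32; the kernel name and row-major layout are assumptions for the example, and real GEMM libraries use matrix-core/tensor-core instructions rather than a scalar loop:
```cpp
#include <cuda_fp16.h>

// Mixed precision: A and B stored as fp16, D produced in fp32, accumulation in fp32.
__global__ void matmul_fp16_in_fp32_out(const __half* A, const __half* B, float* D,
                                        int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0.0f;                              // fp32 accumulator ("compute" precision)
        for (int k = 0; k < K; k++) {
            // Widen the fp16 operands to fp32 before the multiply-accumulate
            acc += __half2float(A[row * K + k]) * __half2float(B[k * N + col]);
        }
        D[row * N + col] = acc;                        // fp32 output; use __float2half(acc) for an fp16 D
    }
}
```
The memory traffic is halved relative to fp32 storage, while the fp32 accumulator keeps the rounding behavior close to the full-precision case.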
1. GEMM (General Matrix Multiplication):
- The basic operation is C = A × B (more generally C = α·(A × B) + β·C, which is where the alpha/beta scale factors in the library calls below come from)
- Fundamental operation in deep learning, especially transformers
- Core computation in attention mechanisms, linear layers, etc.
2. Triton:
- A programming language for writing GPU kernels
- Lets you write your own custom GEMM implementation
- You control memory layout, tiling, etc.
- Example use: When you need a very specific matrix operation
3. hipBLASLt:
- AMD's ROCm library specialized for matrix multiplication (the counterpart of cuBLASLt)
- Pre-built, highly optimized GEMM implementations
- Focuses on performance for common matrix sizes
- Example use: When you need fast, standard matrix multiplication
4. Transformer Engine:
- NVIDIA's specialized library for transformer models
- Automatically handles precision switching (FP8/FP16/FP32)
- Optimizes GEMM operations specifically for transformer architectures
- Includes specialized kernels for attention and linear layers
- Example use: When building large language models
The relationship:
```
Transformer Model
↓
Transformer Engine
↓
GEMM Operations (can be implemented via:)
↓
hipBLASLt / Triton / Other libraries
↓
GPU Hardware
```
Here is how the same matrix multiplication would be implemented using the different approaches:
1. Basic GEMM Operation (what we want to compute):
```python
# C = A × B
# Where A is (M×K), B is (K×N), and C is (M×N)
```
2. Using Triton (Custom implementation):
```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,       # Pointers to matrices
    M, N, K,                   # Matrix dimensions
    stride_am, stride_ak,      # Memory strides for A
    stride_bk, stride_bn,      # Memory strides for B
    stride_cm, stride_cn,      # Memory strides for C
    BLOCK_SIZE: tl.constexpr,  # Tile size (used for M, N and K here for simplicity)
):
    # Each program instance computes one BLOCK_SIZE x BLOCK_SIZE tile of C
    pid = tl.program_id(0)
    block_i = pid // tl.cdiv(N, BLOCK_SIZE)  # Tile row index
    block_j = pid % tl.cdiv(N, BLOCK_SIZE)   # Tile column index
    offs_m = block_i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = block_j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_k = tl.arange(0, BLOCK_SIZE)
    # Pointers to the first K-tile of A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    # Accumulate in fp32, walking along the K dimension one tile at a time
    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE)):
        k_remaining = K - k * BLOCK_SIZE
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < k_remaining), other=0.0)  # Load tile of A
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < k_remaining) & (offs_n[None, :] < N), other=0.0)  # Load tile of B
        acc += tl.dot(a, b)               # Tile matrix multiply
        a_ptrs += BLOCK_SIZE * stride_ak  # Advance both tiles along K
        b_ptrs += BLOCK_SIZE * stride_bk
    # Store the result tile (masked so ragged edges are skipped); assumes C holds fp32
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```
3. Using hipBLASLt:
```cpp
#include <hipblaslt/hipblaslt.h>

// Initialize hipBLASLt
hipblasLtHandle_t handle;
hipblasLtCreate(&handle);
// Describe the operation: fp16 data, fp32 compute/scale
// (enum names follow the current hipBLASLt API; older releases spell them differently)
hipblasLtMatmulDesc_t matmulDesc;
hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32F, HIP_R_32F);
// Define matrix layouts: fp16, column-major, leading dimension = number of rows
hipblasLtMatrixLayout_t matA, matB, matC;
hipblasLtMatrixLayoutCreate(&matA, HIP_R_16F, M, K, M);
hipblasLtMatrixLayoutCreate(&matB, HIP_R_16F, K, N, K);
hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, M, N, M);
// Execute D = alpha * (A x B) + beta * C
// (algo comes from hipblasLtMatmulAlgoGetHeuristic, not shown here)
hipblasLtMatmul(
    handle,
    matmulDesc,
    &alpha,                    // Scale factor for A x B
    A, matA,                   // Input matrix A
    B, matB,                   // Input matrix B
    &beta,                     // Scale factor for C
    C, matC,                   // Input matrix C
    D, matC,                   // Output matrix D (same layout as C here)
    &algo,                     // Algorithm selected by the heuristic query
    workspace, workspaceSize,  // Temporary workspace and its size in bytes
    stream                     // HIP stream
);
```
4. Using Transformer Engine:
```python
import transformer_engine.pytorch as te
# Create TE layers
linear = te.Linear(in_features, out_features)
# Automatic precision handling
with te.fp8_autocast(enabled=True):
output = linear(input) # Internally uses optimized GEMM
```
Key differences:
1. Triton: You control everything (memory, blocks, compute)
2. hipBLASLt: Pre-optimized, you just call it
3. Transformer Engine: High-level, handles precision automatically
Performance comparison (general case):
```
Speed: hipBLASLt > Transformer Engine > Custom Triton
Flexibility: Triton > hipBLASLt > Transformer Engine
Ease of use: Transformer Engine > hipBLASLt > Triton
```