
8/25/2025

CK Tile Tutorial Day 2 (AMD HIP programming) - Simple GEMM

 Concepts Added:

  • 2D grid/block configuration
  • Matrix multiplication basics
  • Each thread computes one output element

Key Pattern:

// Each thread computes C[row][col]
for (int k = 0; k < K; k++) {
    sum += A[row][k] * B[k][col];
}
=== Thread Mapping Visualization ===
Each thread computes one C[i][j]:

  Block(0,0)        Block(1,0)
  ┌─────────┐      ┌─────────┐
  │T00 T01..│      │T00 T01..│
  │T10 T11..│      │T10 T11..│
  │... ... ..│      │... ... ..│
  └─────────┘      └─────────┘
       ↓                ↓
  C[0:16,0:16]    C[0:16,16:32]

Each thread's work:
  for k in 0..K:
    sum += A[row][k] * B[k][col]
  C[row][col] = sum
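
To make the mapping concrete (reading T<i><j> in the diagram as the thread at row i, column j of its block): with block(16,16), a thread computes row = blockIdx.y * 16 + threadIdx.y and col = blockIdx.x * 16 + threadIdx.x. So thread T11 of Block(1,0) above handles C[1][17], since row = 0*16 + 1 = 1 and col = 1*16 + 1 = 17, which is why Block(1,0) covers the output tile C[0:16,16:32].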

=== Step 2: Simple GEMM ===
Matrix multiply: (64x64) * (64x64) = (64x64)
Launching with grid(4,4), block(16,16)
Result: CORRECT
Time: 0.4232 ms
Performance: 1.23887 GFLOPS

=== Step 2: Simple GEMM ===
Matrix multiply: (128x128) * (128x128) = (128x128)
Launching with grid(8,8), block(16,16)
Result: CORRECT
Time: 0.03824 ms
Performance: 109.684 GFLOPS

Key Concepts Added:
1. 2D grid/block configuration
2. Each thread computes one output element
3. Row-major vs column-major layouts
4. Performance measurement (GFLOPS)
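
For reference, the GFLOPS figures follow directly from the operation count: a GEMM of size MxNxK performs 2*M*N*K floating-point operations (one multiply and one add per term). The 128x128x128 run does 2*128^3 ≈ 4.19 MFLOP, and 4.19 MFLOP / 0.03824 ms ≈ 109.7 GFLOPS, matching the output above; the 64^3 run gives 2*64^3 ≈ 0.52 MFLOP / 0.4232 ms ≈ 1.24 GFLOPS.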

code
// Step 2: Simple GEMM (Matrix Multiplication)
// Building on Step 1, now each thread computes one output element

#include <hip/hip_runtime.h>
#include <iostream>
#include <vector>

// ============================================
// PART 1: Kernel Arguments
// ============================================
struct SimpleGemmKernelArgs {
    const float* a_ptr;  // M x K matrix
    const float* b_ptr;  // K x N matrix
    float* c_ptr;        // M x N matrix
    int M;
    int N;
    int K;

    SimpleGemmKernelArgs(const float* a, const float* b, float* c,
                         int m, int n, int k)
        : a_ptr(a), b_ptr(b), c_ptr(c), M(m), N(n), K(k) {}
};

// ============================================
// PART 2: The Kernel (One thread per output)
// ============================================
struct SimpleGemmKernel {
    static dim3 GridSize(const SimpleGemmKernelArgs& args) {
        // 16x16 threads per block
        int grid_m = (args.M + 15) / 16;
        int grid_n = (args.N + 15) / 16;
        return dim3(grid_n, grid_m, 1);  // Note: x=N, y=M
    }

    static dim3 BlockSize() {
        return dim3(16, 16, 1);  // 16x16 = 256 threads
    }

    __device__ void operator()(const SimpleGemmKernelArgs& args) const {
        // Each thread computes one element of C
        int col = blockIdx.x * blockDim.x + threadIdx.x;  // N dimension
        int row = blockIdx.y * blockDim.y + threadIdx.y;  // M dimension

        // Bounds check
        if (row >= args.M || col >= args.N) return;

        // Compute dot product for C[row][col]
        float sum = 0.0f;
        for (int k = 0; k < args.K; k++) {
            // A is row-major: A[row][k] = A[row * K + k]
            // B is column-major: B[k][col] = B[k + col * K]
            float a_val = args.a_ptr[row * args.K + k];
            float b_val = args.b_ptr[k + col * args.K];
            sum += a_val * b_val;
        }

        // Store result (C is row-major)
        args.c_ptr[row * args.N + col] = sum;
    }
};

// ============================================
// PART 3: Host Code
// ============================================
__global__ void simple_gemm_kernel(SimpleGemmKernelArgs args) {
    SimpleGemmKernel kernel;
    kernel(args);
}

void run_simple_gemm(int M, int N, int K) {
    std::cout << "\n=== Step 2: Simple GEMM ===\n";
    std::cout << "Matrix multiply: (" << M << "x" << K << ") * ("
              << K << "x" << N << ") = (" << M << "x" << N << ")\n";

    // Allocate host memory
    std::vector<float> h_a(M * K);
    std::vector<float> h_b(K * N);
    std::vector<float> h_c(M * N, 0.0f);

    // Initialize with simple values
    for (int i = 0; i < M * K; i++) h_a[i] = 1.0f;
    for (int i = 0; i < K * N; i++) h_b[i] = 2.0f;

    // Allocate device memory
    float *d_a, *d_b, *d_c;
    hipMalloc(&d_a, M * K * sizeof(float));
    hipMalloc(&d_b, K * N * sizeof(float));
    hipMalloc(&d_c, M * N * sizeof(float));

    // Copy to device
    hipMemcpy(d_a, h_a.data(), M * K * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(d_b, h_b.data(), K * N * sizeof(float), hipMemcpyHostToDevice);

    // Create kernel arguments
    SimpleGemmKernelArgs args(d_a, d_b, d_c, M, N, K);

    // Get launch configuration
    dim3 grid = SimpleGemmKernel::GridSize(args);
    dim3 block = SimpleGemmKernel::BlockSize();
    std::cout << "Launching with grid(" << grid.x << "," << grid.y
              << "), block(" << block.x << "," << block.y << ")\n";

    // Launch kernel and time it with events
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start);
    simple_gemm_kernel<<<grid, block>>>(args);
    hipEventRecord(stop);
    hipEventSynchronize(stop);

    float milliseconds = 0;
    hipEventElapsedTime(&milliseconds, start, stop);

    // Copy result back
    hipMemcpy(h_c.data(), d_c, M * N * sizeof(float), hipMemcpyDeviceToHost);

    // Verify (each element should be K * 1.0 * 2.0 = 2K)
    float expected = 2.0f * K;
    bool correct = true;
    for (int i = 0; i < std::min(10, M * N); i++) {
        if (h_c[i] != expected) {
            correct = false;
            break;
        }
    }
    std::cout << "Result: " << (correct ? "CORRECT" : "WRONG") << "\n";
    std::cout << "Time: " << milliseconds << " ms\n";

    // Calculate FLOPS
    double flops = 2.0 * M * N * K;  // 2 ops per multiply-add
    double gflops = (flops / milliseconds) / 1e6;
    std::cout << "Performance: " << gflops << " GFLOPS\n";

    // Cleanup
    hipFree(d_a);
    hipFree(d_b);
    hipFree(d_c);
    hipEventDestroy(start);
    hipEventDestroy(stop);
}

// ============================================
// VISUALIZATION: How threads map to output
// ============================================
void visualize_thread_mapping() {
    std::cout << "\n=== Thread Mapping Visualization ===\n";
    std::cout << "Each thread computes one C[i][j]:\n\n";
    std::cout << "  Block(0,0)        Block(1,0)\n";
    std::cout << "  ┌─────────┐      ┌─────────┐\n";
    std::cout << "  │T00 T01..│      │T00 T01..│\n";
    std::cout << "  │T10 T11..│      │T10 T11..│\n";
    std::cout << "  │... ... ..│      │... ... ..│\n";
    std::cout << "  └─────────┘      └─────────┘\n";
    std::cout << "       ↓                ↓\n";
    std::cout << "  C[0:16,0:16]    C[0:16,16:32]\n\n";
    std::cout << "Each thread's work:\n";
    std::cout << "  for k in 0..K:\n";
    std::cout << "    sum += A[row][k] * B[k][col]\n";
    std::cout << "  C[row][col] = sum\n";
}

// ============================================
// PART 4: Main
// ============================================
int main() {
    std::cout << "MareArts CK Tile Tutorial - Step 2: Simple GEMM\n";
    std::cout << "======================================\n";

    visualize_thread_mapping();

    // Run with different sizes
    run_simple_gemm(64, 64, 64);
    run_simple_gemm(128, 128, 128);

    std::cout << "\nKey Concepts Added:\n";
    std::cout << "1. 2D grid/block configuration\n";
    std::cout << "2. Each thread computes one output element\n";
    std::cout << "3. Row-major vs column-major layouts\n";
    std::cout << "4. Performance measurement (GFLOPS)\n";
    std::cout << "\nProblem: Each thread reads K elements from A and B\n";
    std::cout << "         → Poor memory reuse!\n";
    std::cout << "Next: Add tiling and shared memory for efficiency\n";

    return 0;
}
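
A quick back-of-the-envelope on the "poor memory reuse" point: with one thread per output element, every thread streams K values from A and K from B, so the kernel issues on the order of 2*M*N*K global loads even though A and B together hold only M*K + K*N values. For the 128x128x128 case that is roughly 4.2 million loads for about 33 thousand distinct elements, and removing that redundancy is exactly what the tiling/shared-memory step in the next tutorial is for.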

🙇🏻‍♂️
MareArts

1/17/2025

hipBLASLt type definition explanation


1. About Output Types (D):

No, the output D is not limited to fp32/int32. Looking at the table, D can be:

- fp32

- fp16

- bf16

- fp8

- bf8

- int8


2. Input/Output Patterns:

When A is fp16, you have two options:

```

Option 1:
A: fp16 → B: fp16 → C: fp16 → D: fp16 → Compute: fp32

Option 2:
A: fp16 → B: fp16 → C: fp16 → D: fp32 → Compute: fp32
```


The compute/scale is always higher precision (fp32 or int32) to maintain accuracy during calculations, even if inputs/outputs are lower precision.
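
As a minimal sketch of how this looks when setting up a hipBLASLt matmul for the "fp16 in, fp32 out, fp32 compute" pattern: the compute/scale type is fixed on the matmul descriptor, independently of the storage types chosen for the matrix layouts. The helper name below is mine, and the enum names (HIP_R_16F, HIP_R_32F, HIPBLAS_COMPUTE_32F) are assumed from current hipBLASLt headers and may differ across ROCm versions.

```cpp
#include <hipblaslt/hipblaslt.h>

// Sketch only (error checking omitted): fp16 storage for A/B,
// fp32 output D, fp32 accumulation. Enum names are assumptions
// based on current hipBLASLt headers.
void describe_fp16_in_fp32_out(int64_t M, int64_t N, int64_t K) {
    hipblasLtMatmulDesc_t matmulDesc;
    hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32F, HIP_R_32F);  // compute + scale in fp32

    hipblasLtMatrixLayout_t matA, matB, matD;
    hipblasLtMatrixLayoutCreate(&matA, HIP_R_16F, M, K, M);  // A stored as fp16
    hipblasLtMatrixLayoutCreate(&matB, HIP_R_16F, K, N, K);  // B stored as fp16
    hipblasLtMatrixLayoutCreate(&matD, HIP_R_32F, M, N, M);  // D written as fp32
}
```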


3. Key Patterns in the Table:

- Inputs A and B must always match in type

- C typically matches A and B, except with fp8/bf8 inputs

- When using fp8/bf8 inputs, C and D can be higher precision (fp32, fp16, or bf16)

- The compute precision is always fp32 for floating point types

- For integer operations (int8), the compute precision is int32


4. Why Different Combinations?

- Performance: Lower precision (fp16, fp8) = faster computation + less memory

- Accuracy: Higher precision (fp32) = better accuracy but slower

- Memory Usage: fp16/fp8 use less memory than fp32

- Mixed Precision: Use lower precision for inputs but higher precision for output to balance speed and accuracy


Example Use Cases:

```

High Accuracy Needs:
A(fp32) → B(fp32) → C(fp32) → D(fp32) → Compute(fp32)

Balanced Performance:
A(fp16) → B(fp16) → C(fp16) → D(fp32) → Compute(fp32)

Maximum Performance:
A(fp8) → B(fp8) → C(fp8) → D(fp8) → Compute(fp32)
```


1/15/2025

GEMM, Triton, hipBLASLt, and Transformer Engine concepts


1. GEMM (General Matrix Multiplication):

- This is the basic operation: C = A × B (matrix multiplication)

- Fundamental operation in deep learning, especially transformers

- Core computation in attention mechanisms, linear layers, etc.
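
Written out element-wise, C[i][j] = Σ_k A[i][k] * B[k][j] for k = 0 … K-1, so each output element costs K multiply-adds and the whole GEMM costs 2*M*N*K FLOPs. This is exactly the per-thread loop in the Day 2 HIP kernel shown earlier on this page.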


2. Triton:

- A programming language for writing GPU kernels

- Lets you write your own custom GEMM implementation

- You control memory layout, tiling, etc.

- Example use: When you need a very specific matrix operation


3. hipBLASLt:

- A specialized library just for matrix operations

- Pre-built, highly optimized GEMM implementations

- Focuses on performance for common matrix sizes

- Example use: When you need fast, standard matrix multiplication


4. Transformer Engine:

- NVIDIA's specialized library for transformer models

- Automatically handles precision switching (FP8/FP16/FP32)

- Optimizes GEMM operations specifically for transformer architectures

- Includes specialized kernels for attention and linear layers

- Example use: When building large language models


The relationship:

```

Transformer Model
    ↓
Transformer Engine
    ↓
GEMM Operations (can be implemented via:)
    ↓
hipBLASLt / Triton / Other libraries
    ↓
GPU Hardware
```


Here is how the same matrix multiplication would be implemented using different approaches:


1. Basic GEMM Operation (what we want to compute):

```python

# C = A × B
# Where A is (M×K) and B is (K×N)
```


2. Using Triton (Custom implementation):

```python

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,    # Pointers to matrices
    M, N, K,                # Matrix dimensions
    stride_am, stride_ak,   # Memory strides for A
    stride_bk, stride_bn,   # Memory strides for B
    stride_cm, stride_cn,   # Memory strides for C
    BLOCK_SIZE: tl.constexpr,
):
    # Get program ID
    pid = tl.program_id(0)
    # Calculate block indices
    block_i = pid // (N // BLOCK_SIZE)
    block_j = pid % (N // BLOCK_SIZE)
    # Load blocks from A and B
    a = tl.load(a_ptr + ...)  # Load block from A
    b = tl.load(b_ptr + ...)  # Load block from B
    # Compute block multiplication
    c = tl.dot(a, b)          # Matrix multiply
    # Store result
    tl.store(c_ptr + ..., c)

```


3. Using hipBLASLt:

```cpp

// Initialize hipBLASLt
hipblasLtHandle_t handle;
hipblasLtCreate(&handle);

// Describe the operation: fp32 compute and scale type
hipblasLtMatmulDesc_t matmulDesc;
hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32F, HIP_R_32F);

// Define matrix layouts (fp16 storage, leading dimension = rows)
hipblasLtMatrixLayout_t matA, matB, matC;
hipblasLtMatrixLayoutCreate(&matA, HIP_R_16F, M, K, M);
hipblasLtMatrixLayoutCreate(&matB, HIP_R_16F, K, N, K);
hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, M, N, M);

// Execute GEMM: D = alpha * (A x B) + beta * C (here D is written over C)
hipblasLtMatmul(
    handle,
    matmulDesc,
    &alpha,         // Scale factor
    A, matA,        // Input matrix A
    B, matB,        // Input matrix B
    &beta,          // Scale factor
    C, matC,        // Input matrix C
    C, matC,        // Output matrix D (in-place over C)
    &algo,          // Algorithm from hipblasLtMatmulAlgoGetHeuristic
    workspace,      // Temporary workspace
    workspaceSize,  // Workspace size in bytes
    stream          // HIP stream
);
```


4. Using Transformer Engine:

```python

import transformer_engine.pytorch as te

# Create TE layers
linear = te.Linear(in_features, out_features)

# Automatic precision handling
with te.fp8_autocast():
    output = linear(input)  # Internally uses optimized GEMM
```


Key differences:

1. Triton: You control everything (memory, blocks, compute)

2. hipBLASLt: Pre-optimized, you just call it

3. Transformer Engine: High-level, handles precision automatically


Performance comparison (general case):

```

Speed:       hipBLASLt > Transformer Engine > Custom Triton
Flexibility: Triton > hipBLASLt > Transformer Engine
Ease of use: Transformer Engine > hipBLASLt > Triton
```