
[ROCm] Enable blocksize 32 4-bit quantization and GEMV kernels on AMD CDNA#1887

Open
sstamenk wants to merge 7 commits into bitsandbytes-foundation:main from sstamenk:fix/rocm-cdna-blocksize-32-64

Conversation

@sstamenk (Contributor) commented Mar 4, 2026

Summary

  • Parameterize kQuantizeBlockwiseSmall on quantization block size (QBLOCK_SIZE), decoupling it from BNB_WARP_SIZE. This allows blocksize=32 and blocksize=64 4-bit quantization to work on both CDNA (warp=64) and RDNA (warp=32) by packing multiple quantization blocks per wavefront.
  • Fix hardcoded warp size of 32 in the GEMV 4-bit inference kernel (kgemm_4bit_inference_naive), replacing literal 32 with BNB_WARP_SIZE for correct warp lane indexing and reduction on CDNA.
  • Remove all ROCM_WARP_SIZE_64 runtime guards from Python blocksize checks and test parameterization, now that the kernels handle both warp sizes correctly.

Problem

On CDNA GPUs (gfx9xx, warp size 64), blocksize=32 was impossible and blocksize=64 required a special code path because kQuantizeBlockwiseSmall hardcoded BLOCK_SIZE = BNB_WARP_SIZE. The GEMV kernel also used literal 32 for warp lane math, producing wrong results on warp-64 hardware. As a result, tests for blocksize=32/64 and GEMV were skipped on CDNA via ROCM_WARP_SIZE_64 guards.
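The lane-math part of this can be illustrated with a small Python model (an illustrative simulation written for this discussion, not code from the repository): a shuffle-down style tree reduction written for 32 lanes only ever combines half of a 64-lane CDNA wavefront.

```python
def warp_sum(values, assumed_warp_size):
    """Simulate a shuffle-down tree reduction across one wavefront.

    After log2(assumed_warp_size) halving steps, lane 0 holds the sum
    of the first assumed_warp_size lanes -- and only those lanes.
    """
    lanes = list(values)
    offset = assumed_warp_size // 2
    while offset >= 1:
        # All lanes read pre-step values, as a hardware shuffle would.
        for lane in range(len(lanes) - offset):
            lanes[lane] += lanes[lane + offset]
        offset //= 2
    return lanes[0]


# One value per lane of a 64-wide CDNA wavefront.
values = list(range(64))

# A reduction hardcoded for 32 lanes silently drops lanes 32..63:
print(warp_sum(values, 32))  # 496  (sum of lanes 0..31 only)
# Using the real wavefront width (BNB_WARP_SIZE on CDNA) covers every lane:
print(warp_sum(values, 64))  # 2016 (sum of lanes 0..63)
```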

Changes

csrc/kernels.cu / csrc/kernels.cuh - Reworked kQuantizeBlockwiseSmall:

  • Added QBLOCK_SIZE template parameter (was implicitly BNB_WARP_SIZE)
  • Kernel packs BNB_WARP_SIZE / (QBLOCK_SIZE/2) quant blocks per wavefront
  • CDNA blocksize=32: 4 quant blocks per wavefront (64 threads, 16 per block)
  • CDNA blocksize=64: 2 quant blocks per wavefront (64 threads, 32 per block)
  • CUDA/RDNA blocksize=32: 2 quant blocks per warp (32 threads, 16 per block)
  • Replaced the hardcoded 32 with BNB_WARP_SIZE in kgemm_4bit_inference_naive for warp lane/reduction math
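The packing arithmetic in the bullets above can be sanity-checked with a short Python sketch (the function name is illustrative, not from the codebase):

```python
def quant_blocks_per_wavefront(warp_size, qblock_size):
    # Each thread processes 2 elements, so one quantization block of
    # qblock_size elements occupies qblock_size // 2 threads; the rest
    # of the wavefront is filled with additional quantization blocks.
    threads_per_qblock = qblock_size // 2
    return warp_size // threads_per_qblock


# CDNA (wavefront = 64)
print(quant_blocks_per_wavefront(64, 32))  # 4 blocks, 16 threads each
print(quant_blocks_per_wavefront(64, 64))  # 2 blocks, 32 threads each
# CUDA / RDNA (warp = 32)
print(quant_blocks_per_wavefront(32, 32))  # 2 blocks, 16 threads each
```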

csrc/ops.cu - Updated dispatch to use runtime warp size for grid/block calculation for both blocksize=32 and blocksize=64 on HIP.
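As a rough model of that dispatch (hypothetical names; the actual code lives in csrc/ops.cu), the grid size follows from how many quantization blocks fit in one wavefront:

```python
def launch_config(n, blocksize, warp_size):
    # Total quantization blocks needed for n elements.
    num_blocks = (n + blocksize - 1) // blocksize
    # Quant blocks packed into each wavefront (2 elements per thread).
    num_qb = warp_size // (blocksize // 2)
    # One wavefront per launched block, covering num_qb quant blocks.
    grid = (num_blocks + num_qb - 1) // num_qb
    return grid, warp_size  # i.e. a <<<grid, warp_size>>> launch


# 4096 elements, blocksize=32 on CDNA (warp 64):
# 128 quant blocks, 4 per wavefront -> 32 blocks of 64 threads.
print(launch_config(4096, 32, 64))  # (32, 64)
```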

bitsandbytes/backends/cuda/ops.py - Removed ROCM_WARP_SIZE_64 conditional blocksize checks; blocksize=32 is now valid on all platforms.

bitsandbytes/cextension.py - Removed ROCM_WARP_SIZE_64 and get_rocm_warpsize import (no longer needed).

tests/ - Removed ROCM_WARP_SIZE_64 skip guards from test_functional.py, test_ops.py, test_linear4bit.py, test_linear8bitlt.py, and test_parametrize.py.

Depends on

Test plan

  • RDNA: verify unit test coverage
    16 failed, 2777 passed, 16 skipped, 37 deselected, 32 xfailed, 4095 warnings in 191.94s (0:03:11)

  • CDNA: verify unit test coverage, including newly enabled tests
    4 failed, 2789 passed, 17 skipped, 29 deselected, 32 xfailed, 110 warnings in 216.00s (0:03:35)

    • The following tests failed due to precision issues:
tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=1024-fp32-fc2-nf4-DQ_True-cuda] FAILED
tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=1024-fp32-fc2-nf4-DQ_False-cuda] FAILED
tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=1024-fp32-fc2-fp4-DQ_True-cuda] FAILED
tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=1024-fp32-fc2-fp4-DQ_False-cuda] FAILED

github-actions bot commented Mar 4, 2026

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sstamenk changed the title from "Fix/rocm cdna blocksize 32 64" to "Fix blocksize-32/64 4-bit quantization and GEMV kernels on ROCm (CDNA + RDNA)" on Mar 4, 2026
@sstamenk (Contributor Author) commented Mar 4, 2026

2 additional comments on this:

sstamenk added 5 commits March 4, 2026 18:11
Add bnb_host_warp_size() that queries hipDeviceGetAttribute at runtime
with per-device caching (up to 32 GPUs), replacing the compile-time
BNB_WARP_SIZE macro in host-side dispatch. This fixes the incorrect
default of warp size 64 on RDNA and ensures kernels are dispatched
with the proper parameters.
Remove ROCM_WARP_SIZE_64 guards from all test files now that
blocksize-32/64 quantization and GEMV kernels work on warp-64 hardware.
@sstamenk force-pushed the fix/rocm-cdna-blocksize-32-64 branch from a1db771 to 755d1af on March 4, 2026 17:15
@Abdennacer-Badaoui (Member) left a comment

The changes look good, but I’d like to test them on CUDA to confirm everything works as expected.
Could you also provide the failure numbers or details for the GEMV tests?

@@ -34,24 +51,22 @@ void quantizeBlockwise(
else if (blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 64) {
Member

This would mean that for (CUDA, bs = 64, 4-bit dtypes) we would use kQuantizeBlockwiseSmall, while it should use kQuantizeBlockwise directly.
IMO, on CUDA, blocksize=64 should still use the regular kQuantizeBlockwise kernel. Wdyt @matthewdouglas

@sstamenk (Contributor Author) Mar 4, 2026

Yes, I did this intentionally to minimize the kernel selection logic, since I believe there shouldn't be much of a performance impact between the two, but I haven't tested this explicitly. Alternatively, for bs = 64 we could use kQuantizeBlockwiseSmall on CDNA (ws = 64) and the standard kQuantizeBlockwise on CUDA/RDNA (ws = 32).

Member

Yeah, me too! Your way of writing it is definitely better for code readability. But I’m not sure how much impact there’d be going from kQuantizeBlockwise to kQuantizeBlockwiseSmall. I think it’s negligible, but I’m not entirely sure.

Contributor Author

We can do some benchmarking/profiling of the overhead between kQuantizeBlockwise and kQuantizeBlockwiseSmall on RDNA and CUDA with blocksize == 64 to confirm.

Member

Yes, can you do that please?

Member

I asked Claude for a benchmarking script, and here are the results on A100

Device: NVIDIA A100-SXM4-80GB (sm_80)
CUDA: 12.6, PyTorch: 2.8.0+cu126
Warmup: 50, Iterations: 200, Blocksize: 64

[Screenshot (2026-03-05): benchmark results table]

Member

They are quite the same here

Contributor Author

@matthewdouglas It basically does 200 iterations of quantize_4bit and captures the time taken with torch.cuda.Event:

"""
Quick A/B benchmark for blocksize=64 quantization kernels.

    # 1. Build baseline branch
    git checkout fix/rocm-runtime-warp-size
    cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
    cmake --build build -j$(nproc) && pip install -e .

    # 2. Run benchmark, results saved to std_results.json
    python benchmarks/bench_quick.py --save std_results.json

    # 3. Build candidate branch
    git checkout fix/rocm-cdna-blocksize-32-64
    cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
    cmake --build build -j$(nproc) && pip install -e .

    # 4. Run benchmark, compare against baseline
    python benchmarks/bench_quick.py --save small_results.json --compare std_results.json
"""

import argparse
import json
import sys

import torch
from bitsandbytes.functional import quantize_4bit

SIZES = [1_048_576, 4_194_304, 16_777_216, 67_108_864]
QUANT_TYPES = ["nf4", "fp4"]
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
DTYPE_NAMES = {torch.float16: "f16", torch.bfloat16: "bf16", torch.float32: "f32"}
BLOCKSIZE = 64
WARMUP = 50
ITERS = 200


def bench_one(n, dtype, quant_type, device):
    A = torch.randn(n, dtype=dtype, device=device)

    for _ in range(WARMUP):
        quantize_4bit(A, blocksize=BLOCKSIZE, quant_type=quant_type)
    torch.cuda.synchronize(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream=torch.cuda.current_stream(device))
    for _ in range(ITERS):
        quantize_4bit(A, blocksize=BLOCKSIZE, quant_type=quant_type)
    end.record(stream=torch.cuda.current_stream(device))
    torch.cuda.synchronize(device)

    time_us = start.elapsed_time(end) / ITERS * 1000

    num_absmax = (n + BLOCKSIZE - 1) // BLOCKSIZE
    total_bytes = n * A.element_size() + n // 2 + num_absmax * 4
    bw = (total_bytes / 1e9) / (time_us / 1e6)

    return time_us, bw


def print_header(device):
    props = torch.cuda.get_device_properties(device)
    arch = getattr(props, "gcnArchName", "N/A")
    ws = getattr(props, "warp_size", 32)
    print(f"Device: {props.name} ({arch}, warp_size={ws})")
    print(f"Config: blocksize={BLOCKSIZE}, warmup={WARMUP}, iters={ITERS}")
    print()


def run_all(device):
    results = {}
    total = len(QUANT_TYPES) * len(DTYPES) * len(SIZES)
    done = 0

    for qt in QUANT_TYPES:
        for dt in DTYPES:
            for n in SIZES:
                done += 1
                tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
                time_us, bw = bench_one(n, dt, qt, device)
                results[tag] = {"time_us": round(time_us, 2), "bw_gbs": round(bw, 1)}
                print(f"  [{done}/{total}] {qt} {DTYPE_NAMES[dt]:>3s} N={n:>12,d}  {time_us:8.1f} us  {bw:7.1f} GB/s")

    return results


def print_comparison(base, cand):
    print()
    hdr = f"{'quant':>5s} {'dtype':>5s} {'N':>12s} | {'Base (us)':>10s} {'Cand (us)':>10s} {'Speedup':>8s} | {'Base GB/s':>10s} {'Cand GB/s':>10s}"
    print(hdr)
    print("-" * len(hdr))

    for qt in QUANT_TYPES:
        for dt in DTYPES:
            for n in SIZES:
                tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
                b = base.get(tag)
                c = cand.get(tag)
                if not b or not c:
                    continue
                speedup = b["time_us"] / c["time_us"]
                marker = "**" if speedup >= 1.10 else ""
                print(
                    f"{qt:>5s} {DTYPE_NAMES[dt]:>5s} {n:>12,d} | "
                    f"{b['time_us']:>10.1f} {c['time_us']:>10.1f} {speedup:>7.2f}x{marker} | "
                    f"{b['bw_gbs']:>10.1f} {c['bw_gbs']:>10.1f}"
                )
        print()


def main():
    parser = argparse.ArgumentParser(description="Quick A/B benchmark for blocksize=64 quantization")
    parser.add_argument("--device", type=int, default=0, help="GPU device index (default 0; use HIP_VISIBLE_DEVICES to isolate)")
    parser.add_argument("--save", type=str, help="Save results to JSON file")
    parser.add_argument("--compare", type=str, help="Compare against a previously saved JSON (baseline)")
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.device}")
    torch.cuda.set_device(device)
    print_header(device)

    print("Running benchmarks...")
    results = run_all(device)

    if args.save:
        with open(args.save, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.save}")

    if args.compare:
        with open(args.compare) as f:
            baseline = json.load(f)
        print_comparison(baseline, results)


if __name__ == "__main__":
    main()

Contributor Author

@Abdennacer-Badaoui Same here, on an RTX 5090 the results are within margin of error:

quant dtype            N |  Base (us)  Cand (us)  Speedup |  Base GB/s  Cand GB/s
---------------------------------------------------------------------------------
  nf4   f16    1,048,576 |       32.5       32.4    1.00x |       82.8       83.0
  nf4   f16    4,194,304 |       55.0       54.9    1.00x |      195.6      195.7
  nf4   f16   16,777,216 |      148.7      148.5    1.00x |      289.1      289.5
  nf4   f16   67,108,864 |      522.4      522.0    1.00x |      329.2      329.4
  nf4  bf16    1,048,576 |       32.3       32.1    1.01x |       83.1       83.6
  nf4  bf16    4,194,304 |       55.0       54.8    1.00x |      195.3      196.1
  nf4  bf16   16,777,216 |      149.0      148.5    1.00x |      288.5      289.4
  nf4  bf16   67,108,864 |      522.9      521.5    1.00x |      328.8      329.7
  nf4   f32    1,048,576 |       32.3       32.2    1.00x |      148.2      148.4
  nf4   f32    4,194,304 |       54.9       54.8    1.00x |      348.7      348.9
  nf4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  nf4   f32   67,108,864 |      523.6      523.1    1.00x |      584.8      585.3

  fp4   f16    1,048,576 |       32.3       32.3    1.00x |       83.2       83.3
  fp4   f16    4,194,304 |       54.9       54.8    1.00x |      195.7      196.0
  fp4   f16   16,777,216 |      148.9      148.8    1.00x |      288.8      288.8
  fp4   f16   67,108,864 |      522.9      522.8    1.00x |      328.9      328.9
  fp4  bf16    1,048,576 |       32.1       32.1    1.00x |       83.6       83.8
  fp4  bf16    4,194,304 |       55.0       54.7    1.01x |      195.4      196.4
  fp4  bf16   16,777,216 |      148.8      149.0    1.00x |      288.9      288.5
  fp4  bf16   67,108,864 |      522.8      522.6    1.00x |      328.9      329.1
  fp4   f32    1,048,576 |       32.2       32.1    1.00x |      148.6      149.1
  fp4   f32    4,194,304 |       54.9       56.4    0.97x |      348.7      339.3
  fp4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  fp4   f32   67,108,864 |      523.1      524.2    1.00x |      585.3      584.1

Contributor Author

@matthewdouglas Data with larger N:

Quantize

quant dtype             N |  Base (us)  Cand (us)  Speedup |  Base GB/s  Cand GB/s
----------------------------------------------------------------------------------
  nf4   f16   268,435,456 |     3034.3     2280.2    1.33x |      226.7      301.7
  nf4   f16   536,870,912 |     6015.5     4511.3    1.33x |      228.7      305.0
  nf4   f16 1,073,741,824 |    12055.1     8995.6    1.34x |      228.2      305.9
  nf4  bf16   268,435,456 |     3007.6     2290.6    1.31x |      228.7      300.3
  nf4  bf16   536,870,912 |     5999.8     4534.6    1.32x |      229.3      303.4
  nf4  bf16 1,073,741,824 |    11999.4     9037.4    1.33x |      229.3      304.5
  nf4   f32   268,435,456 |     3155.2     2410.3    1.31x |      388.2      508.1
  nf4   f32   536,870,912 |     6273.9     4788.1    1.31x |      390.4      511.6
  nf4   f32 1,073,741,824 |    12502.6     9478.2    1.32x |      391.8      516.9
  fp4   f16   268,435,456 |     2855.7     2135.2    1.34x |      240.9      322.2
  fp4   f16   536,870,912 |     5664.7     4230.4    1.34x |      242.9      325.2
  fp4   f16 1,073,741,824 |    11328.2     8413.6    1.35x |      242.9      327.0
  fp4  bf16   268,435,456 |     2833.5     2137.4    1.33x |      242.8      321.8
  fp4  bf16   536,870,912 |     5622.8     4237.9    1.33x |      244.7      324.6
  fp4  bf16 1,073,741,824 |    11252.3     8431.7    1.33x |      244.5      326.3
  fp4   f32   268,435,456 |     2978.5     2279.3    1.31x |      411.2      537.3
  fp4   f32   536,870,912 |     5934.0     4515.9    1.31x |      412.8      542.4
  fp4   f32 1,073,741,824 |    11817.6     8945.3    1.32x |      414.5      547.7

Dequantize (same kernel on both branches)

quant dtype             N |  Base (us)  Cand (us)  Speedup |  Base GB/s  Cand GB/s
----------------------------------------------------------------------------------
  nf4   f16   268,435,456 |     1146.1     1145.7    1.00x |      600.2      600.4
  nf4   f16   536,870,912 |     2301.8     2300.3    1.00x |      597.7      598.1
  nf4   f16 1,073,741,824 |     4597.3     4598.4    1.00x |      598.5      598.4
  nf4  bf16   268,435,456 |     1147.4     1147.6    1.00x |      599.5      599.4
  nf4  bf16   536,870,912 |     2283.5     2283.9    1.00x |      602.5      602.3
  nf4  bf16 1,073,741,824 |     4551.5     4552.6    1.00x |      604.5      604.4
  nf4   f32   268,435,456 |     2070.8     2070.9    1.00x |      591.4      591.4
  nf4   f32   536,870,912 |     4134.7     4134.0    1.00x |      592.4      592.5
  nf4   f32 1,073,741,824 |     8265.9     8265.8    1.00x |      592.7      592.7
  fp4   f16   268,435,456 |     1146.4     1146.6    1.00x |      600.0      599.9
  fp4   f16   536,870,912 |     2298.0     2295.0    1.00x |      598.7      599.4
  fp4   f16 1,073,741,824 |     4588.7     4587.6    1.00x |      599.6      599.8
  fp4  bf16   268,435,456 |     1151.1     1151.9    1.00x |      597.6      597.2
  fp4  bf16   536,870,912 |     2290.7     2292.0    1.00x |      600.6      600.2
  fp4  bf16 1,073,741,824 |     4567.6     4569.5    1.00x |      602.4      602.1
  fp4   f32   268,435,456 |     2070.5     2070.6    1.00x |      591.5      591.5
  fp4   f32   536,870,912 |     4134.4     4137.3    1.00x |      592.5      592.0
  fp4   f32 1,073,741,824 |     8267.7     8265.2    1.00x |      592.5      592.7

@sstamenk (Contributor Author) commented Mar 4, 2026

The changes look good, but I’d like to test them on CUDA to confirm everything works as expected. Could you also provide the failure numbers or details for the GEMV tests?

Here is the full log of the run cdna_bnb_unit_tests.log

int grid = (num_blocks + num_qb - 1) / num_qb;
kQuantizeBlockwiseSmall<T, 64, DATA_TYPE><<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
} else {
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
Contributor Author

An observation here: kQuantizeBlockwise in combination with the General8bit data type only utilizes half of the wavefront, since it launches 32 threads on CDNA, which has 64.

Member

Yes, we can discuss this further in another PR. For now, let’s focus on merging this one.

@Abdennacer-Badaoui (Member) commented

I tested your PR on a CUDA device (A100); all tests are passing, including the GEMV tests. They are also all passing on MI325X. On which CDNA GPU did you encounter the issue?
The precision issues seem to be specific to certain AMD hardware. Could you investigate this on your end?
Otherwise, for the problematic devices, we could either relax the thresholds or skip these tests. What do you think, @matthewdouglas @sstamenk ?

@sstamenk (Contributor Author) commented Mar 5, 2026

@Abdennacer-Badaoui I tested this on MI308X (gfx942) with TheRock nightly ROCm 7.12 build. If it's device dependent, I would leave it as is for now and then address it in a follow up when I get the chance to test it on more devices.

@sstamenk changed the title from "Fix blocksize-32/64 4-bit quantization and GEMV kernels on ROCm (CDNA + RDNA)" to "[ROCm] Enable blocksize 32 4-bit quantization and GEMV kernels on AMD CDNA" on Mar 5, 2026