[ROCm] Enable blocksize 32 4-bit quantization and GEMV kernels on AMD CDNA #1887
sstamenk wants to merge 7 commits into bitsandbytes-foundation:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
2 additional comments on this:
Add bnb_host_warp_size(), which queries hipDeviceGetAttribute at runtime with per-device caching (up to 32 GPUs), replacing the compile-time BNB_WARP_SIZE macro in host-side dispatch. This fixes the incorrect defaulting to warp size 64 on RDNA and ensures kernels are dispatched with the proper parameters.
Remove ROCM_WARP_SIZE_64 guards from all test files now that blocksize-32/64 quantization and GEMV kernels work on warp-64 hardware.
Abdennacer-Badaoui left a comment:
The changes look good, but I’d like to test them on CUDA to confirm everything works as expected.
Could you also provide the failure numbers or details for the GEMV tests?
```cpp
@@ -34,24 +51,22 @@ void quantizeBlockwise(
    else if (blocksize == 128)
        kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 64) {
```
This would mean that for (CUDA, bs = 64, 4-bit dtypes) we would use kQuantizeBlockwiseSmall, while it should use kQuantizeBlockwise directly.
IMO, on CUDA, blocksize=64 should still use the regular kQuantizeBlockwise kernel. Wdyt @matthewdouglas
Yes, I did this intentionally to minimize the kernel selection logic, since I believe there shouldn't be much of a performance difference between the two, but I haven't tested this explicitly. Alternatively, for bs = 64 we could use kQuantizeBlockwiseSmall on CDNA (ws = 64) and the standard kQuantizeBlockwise on CUDA/RDNA (ws = 32).
Yeah, me too! Your way of writing it is definitely better for code readability. But I'm not sure how much impact there'd be going from kQuantizeBlockwise to kQuantizeBlockwiseSmall. I think it's negligible, but I'm not entirely sure.
We can do some benchmarking/profiling of the overhead between kQuantizeBlockwise and kQuantizeBlockwiseSmall on RDNA and CUDA with blocksize == 64 to confirm.
Yes, can you do that please?
They are practically the same here.
@matthewdouglas It basically does 200 iterations of quantize_4bit and captures the time taken with torch.cuda.Event.
"""
Quick A/B benchmark for blocksize=64 quantization kernels.
# 1. Build baseline branch
git checkout fix/rocm-runtime-warp-size
cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
cmake --build build -j$(nproc) && pip install -e .
# 2. Run benchmark, results saved to std_results.json
python benchmarks/bench_quick.py --save std_results.json
# 3. Build candidate branch
git checkout fix/rocm-cdna-blocksize-32-64
cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
cmake --build build -j$(nproc) && pip install -e .
# 4. Run benchmark, compare against baseline
python benchmarks/bench_quick.py --save small_results.json --compare std_results.json
"""
import argparse
import json
import sys
import torch
from bitsandbytes.functional import quantize_4bit
SIZES = [1_048_576, 4_194_304, 16_777_216, 67_108_864]
QUANT_TYPES = ["nf4", "fp4"]
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
DTYPE_NAMES = {torch.float16: "f16", torch.bfloat16: "bf16", torch.float32: "f32"}
BLOCKSIZE = 64
WARMUP = 50
ITERS = 200
def bench_one(n, dtype, quant_type, device):
A = torch.randn(n, dtype=dtype, device=device)
for _ in range(WARMUP):
quantize_4bit(A, blocksize=BLOCKSIZE, quant_type=quant_type)
torch.cuda.synchronize(device)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record(stream=torch.cuda.current_stream(device))
for _ in range(ITERS):
quantize_4bit(A, blocksize=BLOCKSIZE, quant_type=quant_type)
end.record(stream=torch.cuda.current_stream(device))
torch.cuda.synchronize(device)
time_us = start.elapsed_time(end) / ITERS * 1000
num_absmax = (n + BLOCKSIZE - 1) // BLOCKSIZE
total_bytes = n * A.element_size() + n // 2 + num_absmax * 4
bw = (total_bytes / 1e9) / (time_us / 1e6)
return time_us, bw
def print_header(device):
props = torch.cuda.get_device_properties(device)
arch = getattr(props, "gcnArchName", "N/A")
ws = getattr(props, "warp_size", 32)
print(f"Device: {props.name} ({arch}, warp_size={ws})")
print(f"Config: blocksize={BLOCKSIZE}, warmup={WARMUP}, iters={ITERS}")
print()
def run_all(device):
results = {}
total = len(QUANT_TYPES) * len(DTYPES) * len(SIZES)
done = 0
for qt in QUANT_TYPES:
for dt in DTYPES:
for n in SIZES:
done += 1
tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
time_us, bw = bench_one(n, dt, qt, device)
results[tag] = {"time_us": round(time_us, 2), "bw_gbs": round(bw, 1)}
print(f" [{done}/{total}] {qt} {DTYPE_NAMES[dt]:>3s} N={n:>12,d} {time_us:8.1f} us {bw:7.1f} GB/s")
return results
def print_comparison(base, cand):
print()
hdr = f"{'quant':>5s} {'dtype':>5s} {'N':>12s} | {'Base (us)':>10s} {'Cand (us)':>10s} {'Speedup':>8s} | {'Base GB/s':>10s} {'Cand GB/s':>10s}"
print(hdr)
print("-" * len(hdr))
for qt in QUANT_TYPES:
for dt in DTYPES:
for n in SIZES:
tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
b = base.get(tag)
c = cand.get(tag)
if not b or not c:
continue
speedup = b["time_us"] / c["time_us"]
marker = "**" if speedup >= 1.10 else ""
print(
f"{qt:>5s} {DTYPE_NAMES[dt]:>5s} {n:>12,d} | "
f"{b['time_us']:>10.1f} {c['time_us']:>10.1f} {speedup:>7.2f}x{marker} | "
f"{b['bw_gbs']:>10.1f} {c['bw_gbs']:>10.1f}"
)
print()
def main():
parser = argparse.ArgumentParser(description="Quick A/B benchmark for blocksize=64 quantization")
parser.add_argument("--device", type=int, default=0, help="GPU device index (default 0; use HIP_VISIBLE_DEVICES to isolate)")
parser.add_argument("--save", type=str, help="Save results to JSON file")
parser.add_argument("--compare", type=str, help="Compare against a previously saved JSON (baseline)")
args = parser.parse_args()
device = torch.device(f"cuda:{args.device}")
torch.cuda.set_device(device)
print_header(device)
print("Running benchmarks...")
results = run_all(device)
if args.save:
with open(args.save, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {args.save}")
if args.compare:
with open(args.compare) as f:
baseline = json.load(f)
print_comparison(baseline, results)
if __name__ == "__main__":
main()
@Abdennacer-Badaoui Same here, on an RTX 5090 the results are within margin of error:
```
quant dtype            N |  Base (us)  Cand (us)  Speedup |  Base GB/s  Cand GB/s
---------------------------------------------------------------------------------
  nf4   f16    1,048,576 |       32.5       32.4    1.00x |       82.8       83.0
  nf4   f16    4,194,304 |       55.0       54.9    1.00x |      195.6      195.7
  nf4   f16   16,777,216 |      148.7      148.5    1.00x |      289.1      289.5
  nf4   f16   67,108,864 |      522.4      522.0    1.00x |      329.2      329.4
  nf4  bf16    1,048,576 |       32.3       32.1    1.01x |       83.1       83.6
  nf4  bf16    4,194,304 |       55.0       54.8    1.00x |      195.3      196.1
  nf4  bf16   16,777,216 |      149.0      148.5    1.00x |      288.5      289.4
  nf4  bf16   67,108,864 |      522.9      521.5    1.00x |      328.8      329.7
  nf4   f32    1,048,576 |       32.3       32.2    1.00x |      148.2      148.4
  nf4   f32    4,194,304 |       54.9       54.8    1.00x |      348.7      348.9
  nf4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  nf4   f32   67,108,864 |      523.6      523.1    1.00x |      584.8      585.3
  fp4   f16    1,048,576 |       32.3       32.3    1.00x |       83.2       83.3
  fp4   f16    4,194,304 |       54.9       54.8    1.00x |      195.7      196.0
  fp4   f16   16,777,216 |      148.9      148.8    1.00x |      288.8      288.8
  fp4   f16   67,108,864 |      522.9      522.8    1.00x |      328.9      328.9
  fp4  bf16    1,048,576 |       32.1       32.1    1.00x |       83.6       83.8
  fp4  bf16    4,194,304 |       55.0       54.7    1.01x |      195.4      196.4
  fp4  bf16   16,777,216 |      148.8      149.0    1.00x |      288.9      288.5
  fp4  bf16   67,108,864 |      522.8      522.6    1.00x |      328.9      329.1
  fp4   f32    1,048,576 |       32.2       32.1    1.00x |      148.6      149.1
  fp4   f32    4,194,304 |       54.9       56.4    0.97x |      348.7      339.3
  fp4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  fp4   f32   67,108,864 |      523.1      524.2    1.00x |      585.3      584.1
```
@matthewdouglas Data with larger N:
Quantize
| quant | dtype | N | Time Base (us) | Time Cand (us) | BW Base (GB/s) | BW Cand (GB/s) | Speedup |
|---|---|---|---|---|---|---|---|
| nf4 | f16 | 268,435,456 | 3034.3 | 2280.2 | 226.7 | 301.7 | 1.33x |
| nf4 | f16 | 536,870,912 | 6015.5 | 4511.3 | 228.7 | 305.0 | 1.33x |
| nf4 | f16 | 1,073,741,824 | 12055.1 | 8995.6 | 228.2 | 305.9 | 1.34x |
| nf4 | bf16 | 268,435,456 | 3007.6 | 2290.6 | 228.7 | 300.3 | 1.31x |
| nf4 | bf16 | 536,870,912 | 5999.8 | 4534.6 | 229.3 | 303.4 | 1.32x |
| nf4 | bf16 | 1,073,741,824 | 11999.4 | 9037.4 | 229.3 | 304.5 | 1.33x |
| nf4 | f32 | 268,435,456 | 3155.2 | 2410.3 | 388.2 | 508.1 | 1.31x |
| nf4 | f32 | 536,870,912 | 6273.9 | 4788.1 | 390.4 | 511.6 | 1.31x |
| nf4 | f32 | 1,073,741,824 | 12502.6 | 9478.2 | 391.8 | 516.9 | 1.32x |
| fp4 | f16 | 268,435,456 | 2855.7 | 2135.2 | 240.9 | 322.2 | 1.34x |
| fp4 | f16 | 536,870,912 | 5664.7 | 4230.4 | 242.9 | 325.2 | 1.34x |
| fp4 | f16 | 1,073,741,824 | 11328.2 | 8413.6 | 242.9 | 327.0 | 1.35x |
| fp4 | bf16 | 268,435,456 | 2833.5 | 2137.4 | 242.8 | 321.8 | 1.33x |
| fp4 | bf16 | 536,870,912 | 5622.8 | 4237.9 | 244.7 | 324.6 | 1.33x |
| fp4 | bf16 | 1,073,741,824 | 11252.3 | 8431.7 | 244.5 | 326.3 | 1.33x |
| fp4 | f32 | 268,435,456 | 2978.5 | 2279.3 | 411.2 | 537.3 | 1.31x |
| fp4 | f32 | 536,870,912 | 5934.0 | 4515.9 | 412.8 | 542.4 | 1.31x |
| fp4 | f32 | 1,073,741,824 | 11817.6 | 8945.3 | 414.5 | 547.7 | 1.32x |
Dequantize (same kernel on both branches)
| quant | dtype | N | Time Base (us) | Time Cand (us) | BW Base (GB/s) | BW Cand (GB/s) | Speedup |
|---|---|---|---|---|---|---|---|
| nf4 | f16 | 268,435,456 | 1146.1 | 1145.7 | 600.2 | 600.4 | 1.00x |
| nf4 | f16 | 536,870,912 | 2301.8 | 2300.3 | 597.7 | 598.1 | 1.00x |
| nf4 | f16 | 1,073,741,824 | 4597.3 | 4598.4 | 598.5 | 598.4 | 1.00x |
| nf4 | bf16 | 268,435,456 | 1147.4 | 1147.6 | 599.5 | 599.4 | 1.00x |
| nf4 | bf16 | 536,870,912 | 2283.5 | 2283.9 | 602.5 | 602.3 | 1.00x |
| nf4 | bf16 | 1,073,741,824 | 4551.5 | 4552.6 | 604.5 | 604.4 | 1.00x |
| nf4 | f32 | 268,435,456 | 2070.8 | 2070.9 | 591.4 | 591.4 | 1.00x |
| nf4 | f32 | 536,870,912 | 4134.7 | 4134.0 | 592.4 | 592.5 | 1.00x |
| nf4 | f32 | 1,073,741,824 | 8265.9 | 8265.8 | 592.7 | 592.7 | 1.00x |
| fp4 | f16 | 268,435,456 | 1146.4 | 1146.6 | 600.0 | 599.9 | 1.00x |
| fp4 | f16 | 536,870,912 | 2298.0 | 2295.0 | 598.7 | 599.4 | 1.00x |
| fp4 | f16 | 1,073,741,824 | 4588.7 | 4587.6 | 599.6 | 599.8 | 1.00x |
| fp4 | bf16 | 268,435,456 | 1151.1 | 1151.9 | 597.6 | 597.2 | 1.00x |
| fp4 | bf16 | 536,870,912 | 2290.7 | 2292.0 | 600.6 | 600.2 | 1.00x |
| fp4 | bf16 | 1,073,741,824 | 4567.6 | 4569.5 | 602.4 | 602.1 | 1.00x |
| fp4 | f32 | 268,435,456 | 2070.5 | 2070.6 | 591.5 | 591.5 | 1.00x |
| fp4 | f32 | 536,870,912 | 4134.4 | 4137.3 | 592.5 | 592.0 | 1.00x |
| fp4 | f32 | 1,073,741,824 | 8267.7 | 8265.2 | 592.5 | 592.7 | 1.00x |
Here is the full log of the run: cdna_bnb_unit_tests.log
```cpp
    int grid = (num_blocks + num_qb - 1) / num_qb;
    kQuantizeBlockwiseSmall<T, 64, DATA_TYPE><<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
} else {
    kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
```
Observation here that kQuantizeBlockwise in combination with the General8bit data type only utilizes half of the wavefront, since it launches 32 threads on CDNA, which has 64-wide wavefronts.
Yes, we can discuss this further in another PR. For now, let’s focus on merging this one.
I tested your PR on a CUDA device (A100); all tests are passing, including the GEMV tests. They are also all passing on MI325X. On which CDNA GPU did you encounter the issue?
@Abdennacer-Badaoui I tested this on MI308X (gfx942) with a TheRock nightly ROCm 7.12 build. If it's device dependent, I would leave it as is for now and address it in a follow-up when I get the chance to test on more devices.
Summary
- Parameterize `kQuantizeBlockwiseSmall` on quantization block size (`QBLOCK_SIZE`), decoupling it from `BNB_WARP_SIZE`. This allows blocksize=32 and blocksize=64 4-bit quantization to work on both CDNA (warp=64) and RDNA (warp=32) by packing multiple quantization blocks per wavefront.
- Fix the GEMV kernel (`kgemm_4bit_inference_naive`), replacing literal `32` with `BNB_WARP_SIZE` for correct warp lane indexing and reduction on CDNA.
- Remove `ROCM_WARP_SIZE_64` runtime guards from Python blocksize checks and test parameterization, now that the kernels handle both warp sizes correctly.

Problem
On CDNA GPUs (gfx9xx, warp size 64), blocksize=32 was impossible and blocksize=64 required a special code path because `kQuantizeBlockwiseSmall` hardcoded `BLOCK_SIZE = BNB_WARP_SIZE`. The GEMV kernel also used literal `32` for warp lane math, producing wrong results on warp-64 hardware. As a result, tests for blocksize=32/64 and GEMV were skipped on CDNA via `ROCM_WARP_SIZE_64` guards.

Changes
- `csrc/kernels.cu` / `csrc/kernels.cuh` - Reworked `kQuantizeBlockwiseSmall`:
  - `QBLOCK_SIZE` template parameter (was implicitly `BNB_WARP_SIZE`)
  - `BNB_WARP_SIZE / (QBLOCK_SIZE / 2)` quant blocks per wavefront
  - `32` → `BNB_WARP_SIZE` in `kgemm_4bit_inference_naive` for warp lane/reduction math
- `csrc/ops.cu` - Updated dispatch to use runtime warp size for grid/block calculation for both blocksize=32 and blocksize=64 on HIP.
- `bitsandbytes/backends/cuda/ops.py` - Removed `ROCM_WARP_SIZE_64` conditional blocksize checks; blocksize=32 is now valid on all platforms.
- `bitsandbytes/cextension.py` - Removed `ROCM_WARP_SIZE_64` and the `get_rocm_warpsize` import (no longer needed).
- `tests/` - Removed `ROCM_WARP_SIZE_64` skip guards from `test_functional.py`, `test_ops.py`, `test_linear4bit.py`, `test_linear8bitlt.py`, and `test_parametrize.py`.

Depends on
- `bnb_host_warp_size()` and the `common.cuh` fallback chain

Test plan
RDNA: verify unit test coverage

```
16 failed, 2777 passed, 16 skipped, 37 deselected, 32 xfailed, 4095 warnings in 191.94s (0:03:11)
```

CDNA: verify unit test coverage, including newly enabled tests

```
4 failed, 2789 passed, 17 skipped, 29 deselected, 32 xfailed, 110 warnings in 216.00s (0:03:35)
```