diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 99c729efbc..1e85afe61e 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1531,6 +1531,70 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): ) +class LAQConfig(QuantizeAlgorithmConfig): + """Config for LAQ (Learnt Amax Quantization) algorithm. + + LAQ uses separate learnable pre-quantization and post-dequantization amax + values. Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``. + + ``learnable_amax`` controls which amax parameters are learnable vs frozen: + - ``["pre", "post"]``: both learnable + - ``"post"`` or ``["post"]``: only post learnable, pre frozen + - ``"pre"`` or ``["pre"]``: only pre learnable, post frozen + - ``[]``: both frozen (static scales) + + ``tied_amax`` makes pre and post share a single tensor (requires both to + have the same learnable state, i.e. ``learnable_amax`` must be + ``["pre", "post"]`` or ``[]``). + """ + + method: Literal["laq"] = ModeloptField("laq") + + learnable_amax: list[Literal["pre", "post"]] | Literal["pre", "post"] = ModeloptField( + default=["post"], + title="Which amax parameters are learnable.", + description=( + "Which amax params are learnable. " + "'pre', 'post', ['pre', 'post'], or []. " + "Defaults to ['post'] (post-only learnable)." + ), + ) + + tied_amax: bool = ModeloptField( + default=False, + title="Tie pre and post amax into a single tensor.", + description=( + "If True, pre and post share one underlying tensor. " + "Requires both to have the same learnable state." + ), + ) + + scale_algorithm: dict | None = ModeloptField( + default=None, + title="Scale calibration algorithm to run first.", + description=( + "Dict with 'method' key: 'mse', 'local_hessian', or 'max'. " + "Optional keys include 'fp8_scale_sweep' for FP4 formats. " + "Defaults to {'method': 'mse'} if None." 
+ ), + ) + + @model_validator(mode="after") + def _validate_tied_amax(self): + """Validate tied_amax is compatible with learnable_amax.""" + learn = self.learnable_amax + if isinstance(learn, str): + learn = [learn] + learn_set = set(learn) + if self.tied_amax: + if learn_set not in (set(), {"pre", "post"}): + raise ValueError( + f"tied_amax=True requires learnable_amax to be [] or ['pre', 'post'], " + f"got {self.learnable_amax}" + ) + return self + + QuantizeQuantCfgType = list[QuantizerCfgEntry] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 55f7fdf6fc..3727c36e2f 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -36,10 +36,10 @@ normalize_quant_cfg_list, ) from .nn import ( - NVFP4StaticQuantizer, QuantModule, QuantModuleRegistry, SequentialQuantizer, + StaticBlockScaleQuantizer, SVDQuantLinear, TensorQuantizer, ) @@ -131,10 +131,11 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: name = get_unwrapped_name(name, model) state = quantizer_state_dict[name] # TODO: Add a registry for TensorQuantizers and avoid this manual conversion. 
- if state.get("_is_nvfp4_static_quantizer") and not isinstance( - module, NVFP4StaticQuantizer - ): - NVFP4StaticQuantizer.from_tensor_quantizer(module) + if ( + state.get("_is_static_block_scale_quantizer") + or state.get("_is_nvfp4_static_quantizer") # legacy checkpoint compat + ) and not isinstance(module, StaticBlockScaleQuantizer): + StaticBlockScaleQuantizer.from_tensor_quantizer(module) module.set_from_modelopt_state(quantizer_state_dict[name]) for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index c81d5c89c7..ec67a40d3b 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -38,6 +38,7 @@ AWQLiteCalibConfig, CompressConfig, GPTQCalibConfig, + LAQConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -60,6 +61,7 @@ from .model_calib import ( awq, gptq, + laq, local_hessian_calibrate, max_calibrate, mse_calibrate, @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQCalibConfig _calib_func = gptq + + +@CalibrateModeRegistry.register_mode +class LAQModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for LAQ (Learnt Amax Quantization) algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return LAQConfig + + _calib_func = laq diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 35a0e931c9..fde1cc2133 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -29,13 +29,19 @@ from modelopt.torch.opt.searcher import ForwardLoop from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector -from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils import print_rank_0, same_device_as from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState 
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
 from .calib import MseCalibrator, NVFP4MSECalibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
-from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
+from .nn import (
+    NVFP4StaticQuantizer,
+    QuantModule,
+    SequentialQuantizer,
+    StaticBlockScaleQuantizer,
+    TensorQuantizer,
+)
 from .utils import (
     disable_calib,
     enable_fake_quant,
@@ -53,6 +59,8 @@
 
 __all__ = [
     "awq",
+    "gptq",
+    "laq",
     "local_hessian_calibrate",
     "max_calibrate",
     "sequential_calibrate",
@@ -1671,3 +1679,173 @@ def gptq(
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s")
+
+
+def _is_quantized_block_scale(quantizer: StaticBlockScaleQuantizer) -> bool:
+    """Return True if the quantizer's ``_block_sizes`` has ``scale_bits == (4, 3)``."""
+    block_sizes = quantizer._block_sizes
+    return block_sizes is not None and block_sizes.get("scale_bits") == (4, 3)
+
+
+def _convert_to_static_block_quantizers(model: nn.Module):
+    """Convert eligible TensorQuantizers to StaticBlockScaleQuantizer in place.
+
+    A quantizer is eligible when it is not disabled, already holds a non-None ``_amax``,
+    reports ``is_static_block_quant`` with ``_block_sizes`` set, and either has
+    ``_num_bits == (2, 1)`` with ``scale_bits == (4, 3)`` or has an integer ``_num_bits``.
+    """
+    for module in model.modules():
+        if not isinstance(module, TensorQuantizer) or module._disabled:
+            continue
+        if getattr(module, "_amax", None) is None:
+            continue
+        block_sizes = module._block_sizes
+        if not module.is_static_block_quant or block_sizes is None:
+            continue
+        is_nvfp4 = module._num_bits == (2, 1) and block_sizes.get("scale_bits") == (4, 3)
+        if not (is_nvfp4 or isinstance(module._num_bits, int)):
+            continue
+        # A global amax is only derived when the block scales are themselves quantized.
+        global_amax = (
+            reduce_amax(module._amax.clone().detach(), axis=None)
+            if _is_quantized_block_scale(module)
+            else None
+        )
+        StaticBlockScaleQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
+
+
+def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name):
+    """Run scale calibration and convert to StaticBlockScaleQuantizer if needed.
+
+    Args:
+        model: Quantized model to calibrate.
+        forward_loop: Calibration data forward loop passed to the calibration function.
+        scale_algorithm: Dict with a 'method' key ('mse', 'local_hessian' or 'max');
+            remaining keys are forwarded as keyword arguments. ``None`` selects
+            ``{"method": "mse"}``.
+        caller_name: Name of the calling algorithm, used in error messages.
+
+    Raises:
+        ValueError: If 'method' is missing or not one of the supported methods.
+    """
+    if scale_algorithm is None:
+        scale_algorithm = {"method": "mse"}
+
+    calib_funcs = {
+        "mse": mse_calibrate,
+        "local_hessian": local_hessian_calibrate,
+        "max": max_calibrate,
+    }
+    supported = tuple(calib_funcs)
+    method = scale_algorithm.get("method")
+    if method not in supported:
+        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
+        raise ValueError(
+            f"{caller_name}: method must be one of {supported}, got '{method}'"
+        )
+
+    algo_kwargs = {k: v for k, v in scale_algorithm.items() if k != "method"}
+    calib_funcs[method](model, forward_loop=forward_loop, **algo_kwargs)
+
+    # NOTE(review): only the 'max' path converts here; 'mse' and 'local_hessian' are
+    # presumably expected to leave StaticBlockScaleQuantizers behind themselves --
+    # confirm against those calibration functions.
+    if method == "max":
+        _convert_to_static_block_quantizers(model)
+
+
+def _compute_block_scales(quantizer):
+    """Compute per-block and per-tensor scales from a StaticBlockScaleQuantizer.
+
+    Returns (per_block_scale, per_tensor_scale, quantize_scales).
+    """
+    from .nn.modules.tensor_quantizer import _amax_to_scale
+    from .tensor_quant import scaled_e4m3
+
+    amax = quantizer._amax.float()
+    max_representable = quantizer._quant_max_bound
+    quantize_scales = _is_quantized_block_scale(quantizer)
+    per_tensor_scale = None
+
+    with same_device_as(amax):
+        if quantize_scales:
+            global_amax = quantizer._global_amax.float()
+            per_tensor_scale = _amax_to_scale(global_amax, max_representable)
+            per_block_scale = scaled_e4m3(
+                _amax_to_scale(
+                    amax,
+                    max_representable,
+                    min_value=0.002
+                    * per_tensor_scale.view(-1),  # 0.002 ≈ smallest positive FP8 E4M3 value
+                ),
+                per_tensor_scale,
+                None,
+                4,
+                3,
+            )
+        else:
+            per_block_scale = _amax_to_scale(amax, max_representable)
+
+    return per_block_scale, per_tensor_scale, quantize_scales
+
+
+def _iter_weight_quantizers(model):
+    """Yield (module, weight_name, quantizer) for modules with a StaticBlockScaleQuantizer.
+
+    At most one weight quantizer is yielded per module: iteration over a module's
+    weight attributes stops at the first StaticBlockScaleQuantizer that has ``_amax``.
+    """
+    # ``nn.Module.modules()`` already de-duplicates shared modules, so no "seen" set is needed.
+    for module in model.modules():
+        for weight_name in weight_attr_names(module):
+            wq_name = quantizer_attr_names(weight_name).weight_quantizer
+            quantizer = getattr(module, wq_name, None)
+            if isinstance(quantizer, StaticBlockScaleQuantizer) and hasattr(quantizer, "_amax"):
+                yield module, weight_name, quantizer
+                # NOTE(review): later weight quantizers of a multi-weight module are skipped
+                # here -- confirm this is intended.
+                break
+
+
+def _compute_laq_params(quantizer):
+    """Compute amax and scale-quantization params for LAQ."""
+    per_block_scale, per_tensor_scale, quantize_scales = _compute_block_scales(quantizer)
+    amax = per_block_scale * quantizer._quant_max_bound
+    return amax, per_tensor_scale, quantize_scales
+
+
+@torch.no_grad()
+def laq(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    scale_algorithm: dict | None = None,
+    learnable_amax: list | str = ("post",),
+    tied_amax: bool = False,
+    **kwargs,
+):
+    """Run scale calibration then convert to LAQ mode.
+
+    Uses separate pre (quant) and post (dequant) amax values.
+    Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``.
+
+    Args:
+        model: Quantized model.
+        forward_loop: Calibration data forward loop.
+        scale_algorithm: Calibration algorithm config to run first.
+            Dict with 'method' key: 'mse', 'local_hessian', or 'max'.
+            Defaults to {'method': 'mse'} if None.
+        learnable_amax: Which amax params are learnable: 'pre', 'post',
+            ['pre', 'post'], or [].
+        tied_amax: If True, pre and post share a single tensor.
+ """ + _run_scale_calibration(model, forward_loop, scale_algorithm, "laq") + + for module, weight_name, quantizer in _iter_weight_quantizers(model): + amax, per_tensor_scale, quantize_scales = _compute_laq_params(quantizer) + quantizer.enable_laq( + amax, + per_tensor_scale, + quantize_scales, + learnable_amax=learnable_amax, + tied_amax=tied_amax, + ) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3..8d58df6896 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -57,6 +57,8 @@ from ...tensor_quant import ( dynamic_block_quant, fake_tensor_quant, + fp4_cast_ste, + int_cast_ste, scaled_e4m3, static_blockwise_fp4_fake_quant, ) @@ -66,6 +68,7 @@ __all__ = [ "NVFP4StaticQuantizer", "SequentialQuantizer", + "StaticBlockScaleQuantizer", "TensorQuantizer", "TensorQuantizerCache", "is_registered_quant_backend", @@ -1297,17 +1300,49 @@ def _set_buffer(self, key, value): self.register_buffer(key, value) -class NVFP4StaticQuantizer(TensorQuantizer): - """TensorQuantizer for NVFP4 static block quantization with two-level scaling. +def _clamp_scale(scale: torch.Tensor, min_value: float | torch.Tensor = 1e-8) -> torch.Tensor: + """Clamp per-block scale to guard against small/zero values.""" + return torch.where(scale <= min_value, min_value, scale) + +def _amax_to_scale( + amax: torch.Tensor, max_bound: float, min_value: float | torch.Tensor = 1e-8 +) -> torch.Tensor: + """Convert amax to per-block scale, guarding against small/zero values.""" + return _clamp_scale(amax.float() / max_bound, min_value) + + +def _to_local(t: torch.Tensor) -> torch.Tensor: + """Convert DTensor to local tensor (no-op for regular tensors). + + Under FSDP2, learnable parameters are DTensors but the quantizer forward + operates on local tensors (see TensorQuantizer.forward DTensor handling). 
+ to_local() preserves autograd so gradients flow back to the DTensor parameter. + """ + if DTensor is not None and isinstance(t, DTensor): + return t.to_local() + return t + + +class StaticBlockScaleQuantizer(TensorQuantizer): + """TensorQuantizer for static block quantization with two-level scaling. + + Supports both FP4 (E2M1) and INT block quantization formats with configurable + block_size and optional FP8 scale quantization. Uses _global_amax and inherited _amax for per-block amax values. """ + _laq: bool = False + _learnable_amax: list = [] + _tied_amax: bool = False + _quant_max_bound: float = 6.0 + _quantize_scales: bool = True + @classmethod def from_tensor_quantizer( cls, tq: TensorQuantizer, global_amax: torch.Tensor | None = None - ) -> "NVFP4StaticQuantizer": - """Convert a TensorQuantizer to NVFP4StaticQuantizer in-place. + ) -> "StaticBlockScaleQuantizer": + """Convert a TensorQuantizer to StaticBlockScaleQuantizer in-place. Args: tq: The TensorQuantizer to convert. @@ -1318,11 +1353,52 @@ def from_tensor_quantizer( tq.global_amax = global_amax return tq tq.__class__ = cls - tq._is_nvfp4_static_quantizer = True + tq._is_static_block_scale_quantizer = True + + tq._quant_max_bound = float(tq.maxbound) + if global_amax is not None: tq.global_amax = global_amax return tq + @property + def amax_pre(self): + """Pre (quantization) amax. Returns _amax_post when tied.""" + if self._tied_amax: + return self._amax_post + return self._amax_pre + + @property + def amax_post(self): + """Post (dequantization) amax.""" + return self._amax_post + + @property + def amax(self): + """Return amax, derived from learnable amax parameters if in LAQ mode.""" + if self._laq and not self._tied_amax: + raise RuntimeError( + "LAQ with untied amaxes has separate pre and post parameters. " + "Access them via amax_pre / amax_post." 
+            )
+        if self._laq:
+            return self._amax_post
+        if not hasattr(self, "_amax"):
+            return None
+        return self._amax
+
+    @amax.setter
+    def amax(self, value):
+        assert value is not None, "amax cannot be set to None."
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+        if not hasattr(self, "_amax"):
+            self.register_buffer("_amax", value.clone().detach())
+        else:
+            if self._amax.shape != value.shape:
+                raise RuntimeError("Changing shape when setting amax is not allowed.")
+            self._amax.data.copy_(value.clone().detach().to(self._amax.device))
+
     @property
     def global_amax(self):
         """Return global_amax for quantization."""
@@ -1343,20 +1419,123 @@ def global_amax(self, value):
         else:
             self._global_amax.data.copy_(value.clone().detach().to(self._global_amax.device))
 
+    def _short_amax(self, fmt=".4f"):
+        """Short description of amax, accounting for LAQ mode."""
+        if not self._laq:
+            return super()._short_amax(fmt)
+        learn = self._learnable_amax
+        learn_str = "frozen" if not learn else f"learn=[{','.join(learn)}]"
+        if self._tied_amax:
+            return f"LAQ(tied={self._short_tensor(self._amax_post.data, fmt)}, {learn_str})"
+        return (
+            f"LAQ(pre={self._short_tensor(self._amax_pre.data, fmt)}, "
+            f"post={self._short_tensor(self._amax_post.data, fmt)}, {learn_str})"
+        )
+
+    def enable_laq(
+        self,
+        amax: torch.Tensor,
+        per_tensor_scale: torch.Tensor | None = None,
+        quantize_scales: bool = True,
+        learnable_amax: list | str = ("post",),
+        tied_amax: bool = False,
+    ):
+        """Switch to LAQ mode with configurable learnable/frozen amax tensors.
+
+        Any existing ``_amax`` is dropped and replaced by ``_amax_post`` and,
+        unless ``tied_amax`` is set, ``_amax_pre``. Each is registered as an
+        ``nn.Parameter`` when listed in ``learnable_amax`` and as a buffer otherwise.
+
+        Args:
+            amax: Initial amax values (per-block).
+            per_tensor_scale: Optional per-tensor scale (frozen buffer). Needed by the
+                forward pass when ``quantize_scales`` is True.
+            quantize_scales: Whether to FP8-quantize per-block scales.
+            learnable_amax: Which amax params are learnable: 'pre', 'post',
+                ['pre', 'post'], or [].
+            tied_amax: If True, pre and post share a single tensor (``_amax_post``),
+                so only the 'post' entry of ``learnable_amax`` decides learnability.
+        """
+        # Drop any previous amax tensors so repeated calls cannot hit a
+        # buffer/parameter name clash or leave a stale ``_amax_pre`` behind.
+        for stale in ("_amax", "_amax_pre", "_amax_post"):
+            if hasattr(self, stale):
+                delattr(self, stale)
+        learn = {learnable_amax} if isinstance(learnable_amax, str) else set(learnable_amax)
+
+        names = ("_amax_post",) if tied_amax else ("_amax_post", "_amax_pre")
+        for name in names:
+            init = amax.clone().detach().float()
+            if name.removeprefix("_amax_") in learn:
+                setattr(self, name, nn.Parameter(init, requires_grad=True))
+            else:
+                self.register_buffer(name, init)
+
+        if per_tensor_scale is not None:
+            self.register_buffer("_per_tensor_scale", per_tensor_scale.clone().detach().float())
+        self._quantize_scales = quantize_scales
+        self._laq = True
+        self._learnable_amax = sorted(learn)
+        self._tied_amax = tied_amax
+
+    def _cast_ste(self, inputs):
+        """Cast inputs to quantized representable values (no scaling)."""
+        if isinstance(self._num_bits, tuple):
+            return fp4_cast_ste(inputs)
+        return int_cast_ste(inputs, self._num_bits, self._unsigned, self._narrow_range)
+
+    def _maybe_quantize_scale(self, scale_raw):
+        """FP8-quantize a per-block scale if ``_quantize_scales`` is enabled, else pass through."""
+        if self._quantize_scales:
+            return scaled_e4m3(scale_raw, self._per_tensor_scale, None, 4, 3)
+        return scale_raw
+
     def _fake_quantize(self, inputs):
         """Fake quantization using two-level scaling with _amax and _global_amax."""
-        if self.amax is not None:
-            return static_blockwise_fp4_fake_quant(
-                inputs,
-                self.amax,
-                self.global_amax,  # Can be None, will be computed internally
-                True,  # quantize_block_scales
-                inputs.dtype,
-                self._pass_through_bwd,
+        if self._laq:
+            if self._quantize_scales and getattr(self, "_per_tensor_scale", None) is None:
+                raise RuntimeError(
+                    "LAQ with quantize_scales=True requires a per-tensor scale; "
+                    "pass per_tensor_scale to enable_laq()."
+                )
+            # 0.002 ≈ smallest positive FP8 E4M3 value; clamps per-block scale floor
+            _scale_min = 0.002 * self._per_tensor_scale.view(-1) if self._quantize_scales else 1e-8
+
+            scale_post = self._maybe_quantize_scale(
+                _amax_to_scale(
+                    _to_local(self.amax_post),
+                    self._quant_max_bound,
+                    min_value=_scale_min,
+                )
             )
+            scale_pre = self._maybe_quantize_scale(
+                _amax_to_scale(
+                    _to_local(self.amax_pre),
+                    self._quant_max_bound,
+                    min_value=_scale_min,
+                )
+            )
+            # The ``view(-1, 1)`` broadcasts assume ``inputs`` is laid out as
+            # [num_blocks, block_size] -- TODO confirm against the caller.
+            quant_input = inputs.float() / scale_pre.float().view(-1, 1)
+            w_cast = self._cast_ste(quant_input)
+            return (w_cast * scale_post.view(-1, 1).to(w_cast.dtype)).to(inputs.dtype)
+
+        if self.amax is not None and isinstance(self._num_bits, tuple):
+            return static_blockwise_fp4_fake_quant(
+                inputs,
+                self.amax,
+                self.global_amax,  # Can be None, will be computed internally
+                True,  # quantize_block_scales
+                inputs.dtype,
+                self._pass_through_bwd,
+            )
         return super()._fake_quantize(inputs)
 
 
+NVFP4StaticQuantizer = StaticBlockScaleQuantizer
+
+
 class SequentialQuantizer(nn.Sequential):
     """A sequential container for :class:`TensorQuantizer` modules.
 
diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py
index 16b9d32997..c4b7ab6753 100644
--- a/modelopt/torch/quantization/tensor_quant.py
+++ b/modelopt/torch/quantization/tensor_quant.py
@@ -642,7 +642,65 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
     return outputs
 
 
+class FP4CastSTEFunction(Function):
+    """FP4 cast with STE backward -- no scale/descale, just rounding."""
+
+    @staticmethod
+    def forward(ctx, x, out_dtype=None, rounding="rne"):
+        """Forward pass: cast to FP4 using triton kernel.
+
+        Args:
+            x: Input tensor of shape [NUM_BLOCKS, BLOCK_SIZE].
+            out_dtype: Output dtype. Defaults to x.dtype.
+            rounding: Rounding mode. Only ``"rne"`` (round to nearest even) is
+                implemented by the triton kernel.
+        """
+        if not triton_kernel.IS_AVAILABLE:
+            raise RuntimeError("FP4CastSTEFunction requires triton.")
+        ctx.save_for_backward(x)
+        return triton_kernel.static_blockwise_fp4_cast(x, out_dtype, rounding=rounding)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        """Backward pass: STE with clip mask at |x| <= 6.0."""
+        (x,) = ctx.saved_tensors
+        grad = torch.where(x.abs() <= 6.0, grad_outputs, torch.zeros_like(grad_outputs))
+        return grad, None, None
+
+
+class IntCastSTEFunction(Function):
+    """Integer quantization cast with STE backward, analogous to FP4CastSTEFunction."""
+
+    @staticmethod
+    def forward(ctx, x, num_bits, unsigned=False, narrow_range=True):
+        """Forward pass: clamp-round to integer range."""
+        max_bound = (2.0 ** (num_bits - 1 + int(unsigned))) - 1.0
+        if unsigned:
+            min_bound = 0
+        elif narrow_range:
+            min_bound = -max_bound
+        else:
+            min_bound = -max_bound - 1
+        ctx.save_for_backward(x)
+        ctx.min_bound = min_bound
+        ctx.max_bound = max_bound
+        return torch.clamp(x.round(), min_bound, max_bound)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        """Backward pass: STE with clip mask."""
+        (x,) = ctx.saved_tensors
+        grad = torch.where(
+            (x >= ctx.min_bound) & (x <= ctx.max_bound),
+            grad_outputs,
+            torch.zeros_like(grad_outputs),
+        )
+        return grad, None, None, None
+
+
 fake_tensor_quant = FakeTensorQuantFunction.apply
 scaled_e4m3 = ScaledE4M3Function.apply
 dynamic_block_quant = DynamicBlockQuantizationFunction.apply
 static_blockwise_fp4_fake_quant = StaticBlockwiseFP4FakeQuantFunction.apply
+fp4_cast_ste = FP4CastSTEFunction.apply
+int_cast_ste = IntCastSTEFunction.apply
diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py
index 63a8b3dcb7..59bff3e031 100644
--- a/modelopt/torch/quantization/triton/fp4_kernel.py
+++ b/modelopt/torch/quantization/triton/fp4_kernel.py
@@ -24,7 +24,7 @@
 import triton
 import triton.language as tl
 
-__all__ = ["fp4_dequantize", "static_blockwise_fp4_fake_quant"]
+__all__ = ["fp4_dequantize", "static_blockwise_fp4_cast", "static_blockwise_fp4_fake_quant"]
 
 
 _TORCH_TO_TL_DTYPE = {
@@ -299,3 +299,95 @@ def static_blockwise_fp4_fake_quant(
     )
 
     return y_flat.view_as(x)
+
+
+@triton.jit
+def static_blockwise_fp4_cast_kernel(
+    x_ptr,  # [NUM_ELEMENTS] flattened pre-scaled input
+    y_ptr,  # [NUM_ELEMENTS] flattened output
+    NUM_ELEMENTS,
+    TILE_SIZE: tl.constexpr,
+    OUT_DTYPE: tl.constexpr,
+):
+    """Round pre-scaled values to nearest FP4 representable value (no scale)."""
+    pid = tl.program_id(axis=0)
+    offset = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
+    mask = offset < NUM_ELEMENTS
+
+    x = tl.load(x_ptr + offset, mask=mask).to(tl.float32)
+    x_abs = tl.abs(x)
+
+    # FP4 E2M1 representable values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
+    # NOTE: a NaN input fails every comparison below and therefore comes out as -6.0.
+    q_val = tl.where(
+        x_abs <= 0.25,
+        0.0,
+        tl.where(
+            x_abs < 0.75,
+            0.5,
+            tl.where(
+                x_abs <= 1.25,
+                1.0,
+                tl.where(
+                    x_abs < 1.75,
+                    1.5,
+                    tl.where(
+                        x_abs <= 2.5,
+                        2.0,
+                        tl.where(
+                            x_abs < 3.5,
+                            3.0,
+                            tl.where(x_abs <= 5.0, 4.0, 6.0),
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )
+
+    y = tl.where(x >= 0, q_val, -q_val)
+    tl.store(y_ptr + offset, y.to(OUT_DTYPE), mask=mask)
+
+
+def static_blockwise_fp4_cast(
+    x: torch.Tensor,
+    out_dtype: torch.dtype | None = None,
+    rounding: str = "rne",
+) -> torch.Tensor:
+    """Round pre-scaled values to nearest FP4 E2M1 representable value.
+
+    Unlike ``static_blockwise_fp4_fake_quant``, this does **not** apply any
+    scale -- the caller is responsible for pre-dividing by scale_pre and
+    post-multiplying by scale_post (as in LAQ).
+
+    Args:
+        x: Input tensor (any shape) on CUDA.
+        out_dtype: Output dtype. Defaults to x.dtype.
+        rounding: Rounding mode (only ``"rne"`` supported currently).
+
+    Raises:
+        ValueError: If ``rounding`` is anything other than ``"rne"``.
+    """
+    if rounding != "rne":
+        raise ValueError(f"Unsupported rounding mode {rounding!r}; only 'rne' is implemented.")
+    if out_dtype is None:
+        out_dtype = x.dtype
+
+    x_flat = x.contiguous().view(-1)
+    y_flat = torch.empty_like(x_flat, dtype=out_dtype)
+    NUM_ELEMENTS = x_flat.numel()
+    TILE_SIZE = 1024
+
+    tl_out_dtype = _torch_dtype_to_tl(out_dtype)
+    grid = ((NUM_ELEMENTS + TILE_SIZE - 1) // TILE_SIZE,)
+
+    with torch.cuda.device(x.device):
+        static_blockwise_fp4_cast_kernel[grid](
+            x_flat,
+            y_flat,
+            NUM_ELEMENTS,
+            TILE_SIZE=TILE_SIZE,
+            OUT_DTYPE=tl_out_dtype,
+        )
+
+    return y_flat.view_as(x)
diff --git a/tests/gpu/torch/quantization/test_laq_cuda.py b/tests/gpu/torch/quantization/test_laq_cuda.py
new file mode 100644
index 0000000000..256bcc25dc
--- /dev/null
+++ b/tests/gpu/torch/quantization/test_laq_cuda.py
@@ -0,0 +1,174 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""GPU unit tests for the LAQ algorithm using FP4 (NVFP4) quantization.""" + +import pytest +import torch +from torch import nn + +import modelopt.torch.quantization as mtq + +NVFP4_LAQ_POST_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + }, + "algorithm": { + "method": "laq", + "learnable_amax": ["post"], + "scale_algorithm": {"method": "mse", "fp8_scale_sweep": True}, + }, +} + +NVFP4_LAQ_PRE_POST_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + }, + "algorithm": { + "method": "laq", + "learnable_amax": ["pre", "post"], + "scale_algorithm": {"method": "mse", "fp8_scale_sweep": True}, + }, +} + +NVFP4_LAQ_TIED_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + }, + "algorithm": { + "method": "laq", + "learnable_amax": ["pre", "post"], + "tied_amax": True, + "scale_algorithm": {"method": "mse", "fp8_scale_sweep": True}, + }, +} + + +class SimpleModel(nn.Module): + """Minimal model for LAQ testing.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(64, 64, bias=False) + + def forward(self, x): + return self.linear(x) + + +def _make_forward_loop(model, device): + x = torch.randn(2, 64, device=device) + + def forward_loop(m): + m(x) + + return forward_loop + + +@pytest.mark.parametrize( + "config", + [NVFP4_LAQ_POST_MSE_CFG, NVFP4_LAQ_PRE_POST_MSE_CFG, NVFP4_LAQ_TIED_MSE_CFG], + ids=["post_only", "pre_and_post", "tied"], +) +def test_laq_quantize_e2e(config): + """End-to-end: quantize a small model with 
LAQ + NVFP4 on GPU.""" + device = torch.device("cuda") + model = SimpleModel().to(device) + forward_loop = _make_forward_loop(model, device) + + model = mtq.quantize(model, config, forward_loop=forward_loop) + + # Verify the model still produces output of the correct shape + x = torch.randn(2, 64, device=device) + out = model(x) + assert out.shape == (2, 64) + + +def test_laq_fp4_fake_quantize_differentiable(): + """Test that _fake_quantize in FP4 LAQ mode is differentiable.""" + from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + StaticBlockScaleQuantizer, + TensorQuantizer, + ) + + device = torch.device("cuda") + tq = TensorQuantizer() + tq._num_bits = (2, 1) + tq._unsigned = False + tq._narrow_range = True + tq._disabled = False + tq._block_sizes = {-1: 16, "type": "static", "scale_bits": (4, 3)} + tq._pass_through_bwd = True + tq.register_buffer("_amax", torch.ones(4, device=device)) + tq.to(device) + sbsq = StaticBlockScaleQuantizer.from_tensor_quantizer( + tq, global_amax=torch.tensor(1.0, device=device) + ) + + amax = torch.ones(4, device=device) * 3.0 + per_tensor_scale = torch.tensor(1.0 / 6.0, device=device) + sbsq.enable_laq( + amax, + per_tensor_scale=per_tensor_scale, + quantize_scales=True, + learnable_amax=["post"], + ) + + x = torch.randn(4, 16, device=device) + out = sbsq._fake_quantize(x) + assert out.shape == x.shape + out.sum().backward() + assert sbsq._amax_post.grad is not None + + +def test_laq_fp4_cast_ste(): + """Test fp4_cast_ste on GPU.""" + from modelopt.torch.quantization.tensor_quant import fp4_cast_ste + + device = torch.device("cuda") + x = torch.tensor([[-3.0, 1.5, 0.0, 6.0, -6.0, 0.5, -0.5, 2.0]], device=device) + x.requires_grad_(True) + # fp4_cast_ste expects [NUM_BLOCKS, BLOCK_SIZE] -- pad to block size 16 + x_padded = torch.zeros(1, 16, device=device, requires_grad=True) + with torch.no_grad(): + x_padded[:, : x.shape[1]] = x.detach() + x_padded = x_padded.clone().detach().requires_grad_(True) + y = 
fp4_cast_ste(x_padded) + assert y.shape == x_padded.shape + y.sum().backward() + assert x_padded.grad is not None diff --git a/tests/unit/recipe/test_laq_recipes.py b/tests/unit/recipe/test_laq_recipes.py new file mode 100644 index 0000000000..e6d4be4b1f --- /dev/null +++ b/tests/unit/recipe/test_laq_recipes.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Unit tests for LAQ PTQ recipe YAML files in configs/quantize/."""

from pathlib import Path

import pytest
import yaml

# Recipe directory, resolved from the fourth ancestor directory of this test module.
CONFIGS_DIR = Path(__file__).resolve().parents[3] / "examples" / "llm_qat" / "configs" / "quantize"

# (filename, expected learnable_amax, expected tied_amax)
_LAQ_RECIPES = [
    ("nvfp4_laq_post-mse_init-fp8_kv.yml", ["post"], False),
    ("nvfp4_laq_pre-mse_init-fp8_kv.yml", ["pre"], False),
    ("nvfp4_laq_pre_post-mse_init-fp8_kv.yml", ["pre", "post"], False),
    ("nvfp4_laq_pre_post_tied-mse_init-fp8_kv.yml", ["pre", "post"], True),
    ("nvfp4_laq_frozen-mse_init-fp8_kv.yml", [], False),
]

# Filenames double as the pytest ids so every test reports which recipe failed.
_RECIPE_FILENAMES = [recipe[0] for recipe in _LAQ_RECIPES]

# Shared parametrization for the tests that only need the recipe filename; this
# avoids threading unused placeholder arguments through their signatures.
_for_each_recipe = pytest.mark.parametrize("filename", _RECIPE_FILENAMES, ids=_RECIPE_FILENAMES)


def _load_yaml(filename):
    """Parse the recipe ``filename`` found under ``CONFIGS_DIR`` and return its content.

    The file is opened explicitly as UTF-8 so that parsing does not depend on
    the locale's default encoding.
    """
    path = CONFIGS_DIR / filename
    with path.open(encoding="utf-8") as f:
        return yaml.safe_load(f)


def _load_quant_cfg(filename):
    """Return the ``quantize.quant_cfg`` list of the recipe ``filename``."""
    return _load_yaml(filename)["quantize"]["quant_cfg"]


def _find_entry(quant_cfg, quantizer_name):
    """Find entry by quantizer_name in the quant_cfg list.

    Returns the first entry whose ``"quantizer_name"`` equals ``quantizer_name``.

    Raises:
        KeyError: if no entry carries that quantizer name.
    """
    for entry in quant_cfg:
        if entry.get("quantizer_name") == quantizer_name:
            return entry
    raise KeyError(f"No entry with quantizer_name={quantizer_name!r}")


# ---------------------------------------------------------------------------
# Parametrized load & parse test
# ---------------------------------------------------------------------------


@_for_each_recipe
def test_recipe_loads_and_has_required_sections(filename):
    """Each LAQ recipe YAML is parseable and has metadata + quantize."""
    data = _load_yaml(filename)
    assert "metadata" in data
    assert data["metadata"]["recipe_type"] == "ptq"
    assert "quantize" in data
    assert "algorithm" in data["quantize"]
    assert "quant_cfg" in data["quantize"]


# ---------------------------------------------------------------------------
# Algorithm structure test
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    ("filename", "expected_learnable", "expected_tied"),
    _LAQ_RECIPES,
    ids=_RECIPE_FILENAMES,
)
def test_algorithm_has_correct_laq_params(filename, expected_learnable, expected_tied):
    """Algorithm section has correct method, learnable_amax, tied_amax, and scale_algorithm."""
    algo = _load_yaml(filename)["quantize"]["algorithm"]
    assert algo["method"] == "laq"
    assert algo["learnable_amax"] == expected_learnable
    assert algo["tied_amax"] is expected_tied
    assert algo["scale_algorithm"] == {"method": "mse", "fp8_scale_sweep": True}


# ---------------------------------------------------------------------------
# Weight quantizer uses static type
# ---------------------------------------------------------------------------


@_for_each_recipe
def test_weight_quantizer_is_static(filename):
    """Weight quantizer must use static block type for LAQ learnable scales."""
    w = _find_entry(_load_quant_cfg(filename), "*weight_quantizer")
    assert w["enable"] is True
    assert w["cfg"]["block_sizes"]["type"] == "static"
    assert w["cfg"]["num_bits"] == "e2m1"


# ---------------------------------------------------------------------------
# Input quantizer uses dynamic type
# ---------------------------------------------------------------------------


@_for_each_recipe
def test_input_quantizer_is_dynamic(filename):
    """Input/activation quantizer uses dynamic block type."""
    inp = _find_entry(_load_quant_cfg(filename), "*input_quantizer")
    assert inp["enable"] is True
    assert inp["cfg"]["block_sizes"]["type"] == "dynamic"
    assert inp["cfg"]["num_bits"] == "e2m1"


# ---------------------------------------------------------------------------
# KV cache quantizer enabled
# ---------------------------------------------------------------------------


@_for_each_recipe
def test_kv_cache_quantizer_enabled(filename):
    """FP8 KV cache quantizer is present and enabled."""
    kv = _find_entry(_load_quant_cfg(filename), "*[kv]_bmm_quantizer")
    assert kv["enable"] is True
    assert kv["cfg"]["num_bits"] == "e4m3"


# diff --git a/tests/unit/torch/quantization/test_laq.py b/tests/unit/torch/quantization/test_laq.py
# new file mode 100644
# index 0000000000..e1eed2659f
# --- /dev/null
# +++ b/tests/unit/torch/quantization/test_laq.py
# @@ -0,0 +1,205 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CPU unit tests for the LAQ algorithm using INT4 quantization."""

import pytest
import torch
from torch import nn

from modelopt.torch.quantization.config import LAQConfig
from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
    StaticBlockScaleQuantizer,
    TensorQuantizer,
)
from modelopt.torch.quantization.tensor_quant import int_cast_ste


def _make_int4_static_quantizer(num_amax):
    """Build a StaticBlockScaleQuantizer from a TensorQuantizer set up for signed 4-bit.

    Both test classes below need the same quantizer setup and differ only in the
    length of the initial ``_amax`` buffer, so the construction lives here once.

    Args:
        num_amax: number of elements in the all-ones ``_amax`` buffer.

    Returns:
        The result of ``StaticBlockScaleQuantizer.from_tensor_quantizer``.
    """
    tq = TensorQuantizer()
    tq._num_bits = 4
    tq._unsigned = False
    tq._narrow_range = True
    tq._disabled = False
    tq._block_sizes = {-1: 16}
    tq._pass_through_bwd = True
    tq.register_buffer("_amax", torch.ones(num_amax))
    return StaticBlockScaleQuantizer.from_tensor_quantizer(tq)


class TestLAQConfig:
    """Tests for LAQConfig validation."""

    def test_default_config(self):
        """A default-constructed config exposes the documented default values."""
        cfg = LAQConfig()
        assert cfg.method == "laq"
        assert cfg.learnable_amax == ["post"]
        assert cfg.tied_amax is False
        assert cfg.scale_algorithm is None

    @pytest.mark.parametrize(
        ("learnable_amax", "tied_amax"),
        [
            (["post"], False),
            (["pre"], False),
            (["pre", "post"], False),
            (["pre", "post"], True),
            ([], False),
            ([], True),
            ("post", False),
            ("pre", False),
        ],
    )
    def test_valid_combinations(self, learnable_amax, tied_amax):
        """Every compatible learnable_amax / tied_amax pair is accepted."""
        cfg = LAQConfig(learnable_amax=learnable_amax, tied_amax=tied_amax)
        assert cfg.tied_amax is tied_amax

    @pytest.mark.parametrize(
        "learnable_amax",
        [["post"], ["pre"], "post", "pre"],
    )
    def test_invalid_tied_with_single_learnable(self, learnable_amax):
        """tied_amax=True with only one side learnable is rejected."""
        with pytest.raises(ValueError, match="tied_amax=True requires"):
            LAQConfig(learnable_amax=learnable_amax, tied_amax=True)


class TestEnableLAQ:
    """Tests for StaticBlockScaleQuantizer.enable_laq() with INT4 format."""

    def _make_quantizer(self):
        """Create a StaticBlockScaleQuantizer configured for INT4."""
        sbsq = _make_int4_static_quantizer(8)
        assert sbsq._quant_max_bound == 7.0
        return sbsq

    def _enable(self, learnable_amax, tied_amax):
        """Return a fresh quantizer with LAQ enabled from an amax of eight 3.0 values."""
        q = self._make_quantizer()
        q.enable_laq(
            torch.ones(8) * 3.0,
            quantize_scales=False,
            learnable_amax=learnable_amax,
            tied_amax=tied_amax,
        )
        return q

    def test_post_only_learnable(self):
        """Only the post amax becomes a trainable parameter."""
        q = self._enable(["post"], tied_amax=False)
        assert q._laq is True
        assert isinstance(q._amax_post, nn.Parameter)
        assert q._amax_post.requires_grad is True
        assert not isinstance(q._amax_pre, nn.Parameter)
        assert not q._amax_pre.requires_grad

    def test_pre_only_learnable(self):
        """Only the pre amax becomes a trainable parameter."""
        q = self._enable(["pre"], tied_amax=False)
        assert isinstance(q._amax_pre, nn.Parameter)
        assert q._amax_pre.requires_grad is True
        assert not isinstance(q._amax_post, nn.Parameter)

    def test_both_learnable(self):
        """Pre and post amax are both parameters when both are requested."""
        q = self._enable(["pre", "post"], tied_amax=False)
        assert isinstance(q._amax_pre, nn.Parameter)
        assert isinstance(q._amax_post, nn.Parameter)

    def test_tied_both_learnable(self):
        """With tying, no separate ``_amax_pre`` exists and ``amax_pre`` is the post tensor."""
        q = self._enable(["pre", "post"], tied_amax=True)
        assert q._tied_amax is True
        assert isinstance(q._amax_post, nn.Parameter)
        assert not hasattr(q, "_amax_pre")
        assert q.amax_pre is q._amax_post

    def test_frozen(self):
        """An empty learnable list leaves neither amax as a parameter."""
        q = self._enable([], tied_amax=False)
        assert not isinstance(q._amax_post, nn.Parameter)
        assert not isinstance(q._amax_pre, nn.Parameter)

    def test_old_amax_deleted(self):
        """enable_laq() removes the original ``_amax`` attribute."""
        q = self._make_quantizer()
        assert hasattr(q, "_amax")
        q.enable_laq(torch.ones(8), quantize_scales=False)
        assert not hasattr(q, "_amax")


class TestIntCastSTE:
    """Tests for int_cast_ste (INT4 STE function)."""

    def test_round_trip(self):
        """Output keeps the input shape, stays within [-7, 7], and backpropagates."""
        x = torch.tensor([[-3.2, 1.8, 0.0, 6.5, -7.1]], requires_grad=True)
        y = int_cast_ste(x, 4)
        assert y.shape == x.shape
        max_bound = 7.0
        assert y.min() >= -max_bound
        assert y.max() <= max_bound
        y.sum().backward()
        assert x.grad is not None

    def test_ste_gradient(self):
        """The gradient of the summed output w.r.t. in-range inputs is exactly one."""
        x = torch.tensor([[2.3, -2.3]], requires_grad=True)
        y = int_cast_ste(x, 4)
        y.sum().backward()
        assert torch.all(x.grad == 1.0)


class TestFakeQuantizeLAQ:
    """Tests for _fake_quantize() LAQ path with INT4."""

    def _make_laq_quantizer(self, learnable_amax=("post",), tied_amax=False):
        """Create an INT4 quantizer with LAQ enabled from an amax of four 3.5 values."""
        sbsq = _make_int4_static_quantizer(4)
        amax = torch.ones(4) * 3.5
        sbsq.enable_laq(
            amax, quantize_scales=False, learnable_amax=learnable_amax, tied_amax=tied_amax
        )
        return sbsq

    @staticmethod
    def _backward_through(q):
        """Fake-quantize a random (4, 16) input with ``q`` and backpropagate its sum."""
        x = torch.randn(4, 16)
        out = q._fake_quantize(x)
        out.sum().backward()

    def test_output_shape(self):
        """Fake quantization preserves the input shape."""
        q = self._make_laq_quantizer()
        x = torch.randn(4, 16)
        out = q._fake_quantize(x)
        assert out.shape == x.shape

    def test_differentiable_post(self):
        """Only the post amax receives a gradient when only post is learnable."""
        q = self._make_laq_quantizer(learnable_amax=["post"])
        self._backward_through(q)
        assert q._amax_post.grad is not None
        assert q._amax_pre.grad is None

    def test_differentiable_pre(self):
        """Only the pre amax receives a gradient when only pre is learnable."""
        q = self._make_laq_quantizer(learnable_amax=["pre"])
        self._backward_through(q)
        assert q._amax_pre.grad is not None
        assert q._amax_post.grad is None

    def test_differentiable_both(self):
        """Both amax tensors receive gradients when both are learnable."""
        q = self._make_laq_quantizer(learnable_amax=["pre", "post"])
        self._backward_through(q)
        assert q._amax_pre.grad is not None
        assert q._amax_post.grad is not None

    def test_tied_shares_tensor(self):
        """With tying, the shared tensor receives a gradient and still backs ``amax_pre``."""
        q = self._make_laq_quantizer(learnable_amax=["pre", "post"], tied_amax=True)
        self._backward_through(q)
        assert q._amax_post.grad is not None
        # The test name promises sharing, so verify it explicitly after the pass.
        assert q.amax_pre is q._amax_post