Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,70 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig):
)


class LAQConfig(QuantizeAlgorithmConfig):
    """Config for LAQ (Learnt Amax Quantization) algorithm.

    LAQ keeps two amax values per quantizer — one applied before quantization
    and one after dequantization — and either may be learnable. The forward
    pass is ``w_q = Q_STE(w / s_pre) * s_post`` with ``s = amax / Q_max``.

    ``learnable_amax`` selects which amax parameters are trained vs frozen:

    - ``["pre", "post"]``: both learnable
    - ``"post"`` or ``["post"]``: only post learnable, pre frozen
    - ``"pre"`` or ``["pre"]``: only pre learnable, post frozen
    - ``[]``: both frozen (static scales)

    ``tied_amax`` shares a single underlying tensor between pre and post,
    which is only consistent when both sides have the same learnable state,
    i.e. ``learnable_amax`` is ``["pre", "post"]`` or ``[]``.
    """

    method: Literal["laq"] = ModeloptField("laq")

    learnable_amax: list[Literal["pre", "post"]] | Literal["pre", "post"] = ModeloptField(
        default=["post"],
        title="Which amax parameters are learnable.",
        description=(
            "Which amax params are learnable. "
            "'pre', 'post', ['pre', 'post'], or []. "
            "Defaults to ['post'] (post-only learnable)."
        ),
    )

    tied_amax: bool = ModeloptField(
        default=False,
        title="Tie pre and post amax into a single tensor.",
        description=(
            "If True, pre and post share one underlying tensor. "
            "Requires both to have the same learnable state."
        ),
    )

    scale_algorithm: dict | None = ModeloptField(
        default=None,
        title="Scale calibration algorithm to run first.",
        description=(
            "Dict with 'method' key: 'mse', 'local_hessian', or 'max'. "
            "Optional keys include 'fp8_scale_sweep' for FP4 formats. "
            "Defaults to {'method': 'mse'} if None."
        ),
    )

    @model_validator(mode="after")
    def _validate_tied_amax(self):
        """Reject ``tied_amax`` combined with an asymmetric learnable state."""
        if not self.tied_amax:
            return self
        raw = self.learnable_amax
        # Normalize the str-or-list field to a set of learnable sides.
        states = {raw} if isinstance(raw, str) else set(raw)
        if states not in (set(), {"pre", "post"}):
            raise ValueError(
                f"tied_amax=True requires learnable_amax to be [] or ['pre', 'post'], "
                f"got {self.learnable_amax}"
            )
        return self


# A quantize config's quantizer-entry list: one QuantizerCfgEntry per pattern.
QuantizeQuantCfgType = list[QuantizerCfgEntry]

# Accepted forms for specifying a calibration algorithm: a method name, a raw
# dict, a typed config object, or None (no calibration algorithm).
_QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
Expand Down
11 changes: 6 additions & 5 deletions modelopt/torch/quantization/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
normalize_quant_cfg_list,
)
from .nn import (
NVFP4StaticQuantizer,
QuantModule,
QuantModuleRegistry,
SequentialQuantizer,
StaticBlockScaleQuantizer,
SVDQuantLinear,
TensorQuantizer,
)
Expand Down Expand Up @@ -131,10 +131,11 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
name = get_unwrapped_name(name, model)
state = quantizer_state_dict[name]
# TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
if state.get("_is_nvfp4_static_quantizer") and not isinstance(
module, NVFP4StaticQuantizer
):
NVFP4StaticQuantizer.from_tensor_quantizer(module)
if (
state.get("_is_static_block_scale_quantizer")
or state.get("_is_nvfp4_static_quantizer") # legacy checkpoint compat
) and not isinstance(module, StaticBlockScaleQuantizer):
StaticBlockScaleQuantizer.from_tensor_quantizer(module)
module.set_from_modelopt_state(quantizer_state_dict[name])

for name, module in model.named_modules():
Expand Down
14 changes: 14 additions & 0 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
AWQLiteCalibConfig,
CompressConfig,
GPTQCalibConfig,
LAQConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
MseCalibConfig,
Expand All @@ -60,6 +61,7 @@
from .model_calib import (
awq,
gptq,
laq,
local_hessian_calibrate,
max_calibrate,
mse_calibrate,
Expand Down Expand Up @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
return GPTQCalibConfig

_calib_func = gptq


@CalibrateModeRegistry.register_mode
class LAQModeDescriptor(BaseCalibrateModeDescriptor):
    """Mode for LAQ (Learnt Amax Quantization) algorithm."""

    # Calibration entry point invoked for this mode.
    _calib_func = laq

    @property
    def config_class(self) -> type[QuantizeAlgorithmConfig]:
        """Specifies the config class for the mode."""
        return LAQConfig
158 changes: 156 additions & 2 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils import print_rank_0, same_device_as
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .nn import (
NVFP4StaticQuantizer,
QuantModule,
SequentialQuantizer,
StaticBlockScaleQuantizer,
TensorQuantizer,
)
from .utils import (
disable_calib,
enable_fake_quant,
Expand All @@ -53,6 +59,8 @@

__all__ = [
"awq",
"gptq",
"laq",
"local_hessian_calibrate",
"max_calibrate",
"sequential_calibrate",
Expand Down Expand Up @@ -1671,3 +1679,149 @@ def gptq(
if torch.cuda.is_available():
torch.cuda.empty_cache()
print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s")


def _is_quantized_block_scale(quantizer: StaticBlockScaleQuantizer) -> bool:
if quantizer._block_sizes is None:
return False
scale_bits = quantizer._block_sizes.get("scale_bits", None)
if scale_bits is None:
return False
return scale_bits == (4, 3)


def _convert_to_static_block_quantizers(model: nn.Module):
    """Convert eligible TensorQuantizers to StaticBlockScaleQuantizer.

    A quantizer is eligible when it is enabled, has a calibrated ``_amax``,
    reports static block quantization with block sizes set, and is either
    FP4 weights over FP8 block scales (``_num_bits == (2, 1)`` with
    ``scale_bits == (4, 3)``) or has an integer bit-width. When the block
    scales are themselves FP8-quantized, a per-tensor global amax is reduced
    from the calibrated per-block amax and handed to the converted quantizer.
    """
    for module in model.modules():
        if not isinstance(module, TensorQuantizer) or module._disabled:
            continue
        if getattr(module, "_amax", None) is None:
            continue  # not calibrated yet
        if not module.is_static_block_quant or module._block_sizes is None:
            continue
        is_fp4_over_fp8_scales = (
            module._num_bits == (2, 1) and module._block_sizes.get("scale_bits") == (4, 3)
        )
        if not (is_fp4_over_fp8_scales or isinstance(module._num_bits, int)):
            continue
        global_amax = (
            reduce_amax(module._amax.clone().detach(), axis=None)
            if _is_quantized_block_scale(module)
            else None
        )
        StaticBlockScaleQuantizer.from_tensor_quantizer(module, global_amax=global_amax)


def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name):
"""Run calibration and convert to StaticBlockScaleQuantizer if needed."""
if scale_algorithm is None:
scale_algorithm = {"method": "mse"}

method = scale_algorithm.get("method")
supported = ("mse", "local_hessian", "max")
assert method in supported, f"{caller_name}: method must be one of {supported}, got '{method}'"

algo_kwargs = {k: v for k, v in scale_algorithm.items() if k != "method"}
calib_funcs = {
"mse": mse_calibrate,
"local_hessian": local_hessian_calibrate,
"max": max_calibrate,
}
calib_funcs[method](model, forward_loop=forward_loop, **algo_kwargs)

if method == "max":
_convert_to_static_block_quantizers(model)


def _compute_block_scales(quantizer):
    """Compute per-block and per-tensor scales from a StaticBlockScaleQuantizer.

    Returns (per_block_scale, per_tensor_scale, quantize_scales), where
    ``per_tensor_scale`` is None unless the block scales are themselves
    FP8 (E4M3) quantized, and ``quantize_scales`` flags that case.
    """
    # NOTE(review): imported lazily — presumably to avoid an import cycle with
    # this module; confirm before hoisting to module level.
    from .nn.modules.tensor_quantizer import _amax_to_scale
    from .tensor_quant import scaled_e4m3

    amax = quantizer._amax.float()
    max_representable = quantizer._quant_max_bound
    # True when the quantizer's block config requests scale_bits == (4, 3),
    # i.e. FP8 E4M3 quantized block scales.
    quantize_scales = _is_quantized_block_scale(quantizer)
    per_tensor_scale = None

    with same_device_as(amax):
        if quantize_scales:
            # Two-level scaling: a per-tensor scale from the global amax, and
            # per-block scales quantized to FP8 E4M3 relative to it.
            global_amax = quantizer._global_amax.float()
            per_tensor_scale = _amax_to_scale(global_amax, max_representable)
            # min_value clamps the per-block scales from below before the E4M3
            # cast; the trailing (4, 3) args presumably select the E4M3 format
            # — confirm against the scaled_e4m3 signature.
            per_block_scale = scaled_e4m3(
                _amax_to_scale(
                    amax,
                    max_representable,
                    min_value=0.002
                    * per_tensor_scale.view(-1),  # 0.002 ≈ smallest positive FP8 E4M3 value
                ),
                per_tensor_scale,
                None,
                4,
                3,
            )
        else:
            # Unquantized block scales: one float scale per block, no global scale.
            per_block_scale = _amax_to_scale(amax, max_representable)

    return per_block_scale, per_tensor_scale, quantize_scales


def _iter_weight_quantizers(model):
    """Yield (module, weight_name, quantizer) for each StaticBlockScaleQuantizer with amax.

    At most one (the first matching) weight quantizer is yielded per module,
    and each module is visited only once.
    """
    visited = set()
    for _, candidate in model.named_modules():
        if candidate in visited:
            continue
        for w_name in weight_attr_names(candidate):
            attr = quantizer_attr_names(w_name).weight_quantizer
            wq = getattr(candidate, attr, None)
            if isinstance(wq, StaticBlockScaleQuantizer) and hasattr(wq, "_amax"):
                visited.add(candidate)
                yield candidate, w_name, wq
                break


def _compute_laq_params(quantizer):
    """Compute amax and scale-quantization params for LAQ."""
    block_scale, tensor_scale, scales_quantized = _compute_block_scales(quantizer)
    # Recover amax from the per-block scale: amax = scale * Q_max.
    return block_scale * quantizer._quant_max_bound, tensor_scale, scales_quantized


@torch.no_grad()
def laq(
    model: nn.Module,
    forward_loop: ForwardLoop | None = None,
    scale_algorithm: dict | None = None,
    learnable_amax: list | str | tuple = ("post",),
    tied_amax: bool = False,
    **kwargs,
):
    """Run scale calibration then convert to LAQ mode.

    Uses separate pre (quant) and post (dequant) amax values.
    Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``.

    Args:
        model: Quantized model.
        forward_loop: Calibration data forward loop.
        scale_algorithm: Calibration algorithm config to run first.
            Dict with 'method' key: 'mse', 'local_hessian', or 'max'.
            Defaults to {'method': 'mse'} if None.
        learnable_amax: Which amax params are learnable: 'pre', 'post',
            ['pre', 'post'], or []. The default is a tuple (immutable
            default), hence the ``tuple`` in the annotation.
        tied_amax: If True, pre and post share a single tensor.
        **kwargs: Ignored. Absorbs extra config fields passed through from
            the mode machinery.
    """
    _run_scale_calibration(model, forward_loop, scale_algorithm, "laq")

    # Convert every calibrated weight quantizer to learnable-amax mode.
    for module, weight_name, quantizer in _iter_weight_quantizers(model):
        amax, per_tensor_scale, quantize_scales = _compute_laq_params(quantizer)
        quantizer.enable_laq(
            amax,
            per_tensor_scale,
            quantize_scales,
            learnable_amax=learnable_amax,
            tied_amax=tied_amax,
        )
Loading
Loading