From a587191b3365eb6a233720cefcf98fe239ee6f09 Mon Sep 17 00:00:00 2001 From: "Arty S." <248714260+arty-kk@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:35:09 +0300 Subject: [PATCH] Avoid blocking drain when oversize chunk is terminal --- hyperquant/api/app.py | 104 ++++++++++++++++++++++++++++-------- hyperquant/api/models.py | 103 +++++++++++++++++++++-------------- hyperquant/cli.py | 62 +++++++++++++-------- hyperquant/context_codec.py | 38 ++++++++----- hyperquant/defaults.py | 40 ++++++++++++++ hyperquant/resident_tier.py | 63 ++++++++++++++-------- hyperquant/vector_codec.py | 17 ++++-- tests/array_builders.py | 56 +++++++++++++++++++ tests/test_api.py | 93 ++++++++++++++++++-------------- tests/test_cli.py | 33 ++++++++++++ tests/test_context_codec.py | 44 +-------------- tests/test_resident_tier.py | 42 +++++++++++++++ 12 files changed, 492 insertions(+), 203 deletions(-) create mode 100644 hyperquant/defaults.py create mode 100644 tests/array_builders.py create mode 100644 tests/test_cli.py diff --git a/hyperquant/api/app.py b/hyperquant/api/app.py index 2bd53ba..4cbcb91 100644 --- a/hyperquant/api/app.py +++ b/hyperquant/api/app.py @@ -54,6 +54,78 @@ DEFAULT_MAX_HTTP_BODY_OVERHEAD_BYTES = 1024 * 1024 +class _BodySizeLimitMiddleware: + def __init__(self, app, *, max_http_body_bytes: int, metrics: HyperQuantMetrics) -> None: + self.app = app + self.max_http_body_bytes = int(max_http_body_bytes) + self.metrics = metrics + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + body_too_large_detail = f"request body exceeds max_http_body_bytes={self.max_http_body_bytes}" + for key, value in scope.get("headers", []): + if key.lower() != b"content-length": + continue + try: + declared = int(value.decode("latin1")) + except ValueError: + declared = None + if declared is not None and declared > self.max_http_body_bytes: + self.metrics.observe_error("http", "request_too_large") + await JSONResponse(status_code=413, content={"detail": body_too_large_detail})(scope, receive, send) + return + + seen = 0 + too_large = False + sent_too_large_response = False + request_stream_ended = False + + async def drain_remaining_body() -> None: + nonlocal request_stream_ended + while True: + message = await receive() + if message.get("type") != "http.request": + return + request_stream_ended = not message.get("more_body", False) + if not message.get("more_body", False): + return + + async def guarded_receive(): + nonlocal seen, too_large, request_stream_ended + message = await receive() + if too_large: + return {"type": "http.request", "body": b"", "more_body": False} + if message.get("type") == "http.request": + request_stream_ended = not message.get("more_body", False) + seen += len(message.get("body", b"")) + if seen > self.max_http_body_bytes: + too_large = True + return {"type": "http.request", "body": b"", "more_body": False} + return message + + async def guarded_send(message): + nonlocal sent_too_large_response + if too_large: + if not sent_too_large_response: + sent_too_large_response = True + if not request_stream_ended: + await drain_remaining_body() + self.metrics.observe_error("http", "request_too_large") + await JSONResponse(status_code=413, content={"detail": body_too_large_detail})(scope, receive, send) + return + await send(message) + + await self.app(scope, guarded_receive, guarded_send) + if too_large and not sent_too_large_response: + if not request_stream_ended: + await drain_remaining_body() + self.metrics.observe_error("http", "request_too_large") + await JSONResponse(status_code=413, content={"detail": body_too_large_detail})(scope, receive, send) + + def _pydantic_model_to_dict(model) -> dict: if hasattr(model, "model_dump"): return model.model_dump() @@ -148,22 +220,10 @@ async def run_bound(fn): app.state.max_request_bytes = max_request_bytes app.state.max_http_body_bytes = max_http_body_bytes app.state.max_concurrency = resolved_max_concurrency + app.add_middleware(_BodySizeLimitMiddleware, max_http_body_bytes=max_http_body_bytes, metrics=metrics) - @app.middleware("http") - async def enforce_content_length(request, call_next): - content_length = request.headers.get("content-length") - if content_length is not None: - try: - size = int(content_length) - except ValueError: - size = None - if size is not None and size > max_http_body_bytes: - metrics.observe_error("http", "request_too_large") - return JSONResponse( - status_code=413, - content={"detail": f"request body exceeds max_http_body_bytes={max_http_body_bytes}"}, - ) - return await call_next(request) + def internal_server_error(_exc: Exception) -> HTTPException: + return HTTPException(status_code=500, detail="internal server error") @app.get("/healthz", response_model=HealthResponse) async def healthz() -> HealthResponse: @@ -205,7 +265,7 @@ def do_compress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover - FastAPI behavior tested through endpoint metrics.observe_error("compress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_compress(stats, latency_seconds=time.perf_counter() - started) return CodebookCompressResponse( envelope_b64=envelope.to_base64(), @@ -227,7 +287,7 @@ def do_decompress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("decompress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_decompress(latency_seconds=time.perf_counter() - started) return DecompressResponse(array_b64=ndarray_to_b64(restored)) @@ -247,7 +307,7 @@ def do_vector_compress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("vector_compress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_vector_compress(stats, latency_seconds=time.perf_counter() - started) return VectorCompressResponse( envelope_b64=envelope.to_base64(), @@ -270,7 +330,7 @@ def do_vector_decompress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("vector_decompress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_vector_decompress(latency_seconds=time.perf_counter() - started) return DecompressResponse(array_b64=ndarray_to_b64(restored)) @@ -316,7 +376,7 @@ def do_resident_plan(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("resident_plan", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_resident_plan(plan, latency_seconds=time.perf_counter() - started) return ResidentPlanResponse(plan=plan.to_dict()) @@ -370,7 +430,7 @@ def do_context_compress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("context_compress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_context_compress(stats, latency_seconds=time.perf_counter() - started) return ContextCompressResponse( envelope_b64=envelope.to_base64(), @@ -395,7 +455,7 @@ def do_context_decompress(): raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pragma: no cover metrics.observe_error("context_decompress", "internal_error") - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise internal_server_error(exc) from exc metrics.observe_context_decompress(latency_seconds=time.perf_counter() - started) return DecompressResponse(array_b64=ndarray_to_b64(restored)) diff --git a/hyperquant/api/models.py b/hyperquant/api/models.py index 2afc992..1a91f61 100644 --- a/hyperquant/api/models.py +++ b/hyperquant/api/models.py @@ -18,6 +18,31 @@ from pydantic import BaseModel, Field +from ..defaults import ( + CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT, + CONTEXT_ENABLE_PAGE_REF_DEFAULT, + CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, + CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, + CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_PAGE_SIZE_DEFAULT, + CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_RANK_DEFAULT, + CONTEXT_REF_ROUND_DECIMALS_DEFAULT, + CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT, + RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT, + RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT, + RESIDENT_CONCURRENT_SESSIONS_DEFAULT, + RESIDENT_HOT_PAGES_DEFAULT, + RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT, + VECTOR_BITS_DEFAULT, + VECTOR_GROUP_SIZE_DEFAULT, + VECTOR_PREFER_NATIVE_FWHT_DEFAULT, + VECTOR_RESIDUAL_TOPK_DEFAULT, + VECTOR_ROTATION_SEED_DEFAULT, +) + class CodebookCompressRequest(BaseModel): array_b64: str = Field(..., description="Base64-encoded .npy payload.") @@ -30,11 +55,11 @@ class DecompressRequest(BaseModel): class VectorCompressRequest(BaseModel): array_b64: str = Field(..., description="Base64-encoded .npy payload.") - bits: int = Field(default=3, ge=2, le=4) - group_size: int = Field(default=128, gt=0) - rotation_seed: int = Field(default=17) - residual_topk: int = Field(default=1, ge=0) - prefer_native_fwht: bool = True + bits: int = Field(default=VECTOR_BITS_DEFAULT, ge=2, le=4) + group_size: int = Field(default=VECTOR_GROUP_SIZE_DEFAULT, gt=0) + rotation_seed: int = Field(default=VECTOR_ROTATION_SEED_DEFAULT) + residual_topk: int = Field(default=VECTOR_RESIDUAL_TOPK_DEFAULT, ge=0) + prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT class ContextGuaranteeModel(BaseModel): @@ -47,18 +72,18 @@ class ContextGuaranteeModel(BaseModel): class ContextCompressRequest(BaseModel): array_b64: str = Field(..., description="Base64-encoded .npy payload.") protected_vector_indices: List[int] = Field(default_factory=list) - page_size: int = Field(default=64, gt=0) - rank: int = Field(default=1, gt=0) - prefix_keep_vectors: int = Field(default=32, ge=0) - suffix_keep_vectors: int = Field(default=64, ge=0) - low_rank_error_threshold: float = Field(default=0.03, ge=0.0) - ref_round_decimals: int = Field(default=3, ge=0) - enable_page_ref: bool = True - page_ref_rel_rms_threshold: float = Field(default=0.005, ge=0.0) - enable_int8_fallback: bool = True - try_int8_for_protected: bool = True - int8_rel_rms_threshold: float = Field(default=0.01, ge=0.0) - int8_max_abs_threshold: float = Field(default=0.05, ge=0.0) + page_size: int = Field(default=CONTEXT_PAGE_SIZE_DEFAULT, gt=0) + rank: int = Field(default=CONTEXT_RANK_DEFAULT, gt=0) + prefix_keep_vectors: int = Field(default=CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, ge=0) + suffix_keep_vectors: int = Field(default=CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, ge=0) + low_rank_error_threshold: float = Field(default=CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, ge=0.0) + ref_round_decimals: int = Field(default=CONTEXT_REF_ROUND_DECIMALS_DEFAULT, ge=0) + enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT + page_ref_rel_rms_threshold: float = Field(default=CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, ge=0.0) + enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT + try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT + int8_rel_rms_threshold: float = Field(default=CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, ge=0.0) + int8_max_abs_threshold: float = Field(default=CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, ge=0.0) fail_closed: bool = True guarantee: ContextGuaranteeModel | None = None @@ -120,29 +145,29 @@ class ContextCompressionStatsModel(BaseModel): class ResidentPlanRequest(BaseModel): array_b64: str = Field(..., description="Base64-encoded .npy payload.") - concurrent_sessions: int = Field(default=8, gt=0) - active_window_tokens: int = Field(default=256, gt=0) - runtime_value_bytes: int = Field(default=2, gt=0) + concurrent_sessions: int = Field(default=RESIDENT_CONCURRENT_SESSIONS_DEFAULT, gt=0) + active_window_tokens: int = Field(default=RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT, gt=0) + runtime_value_bytes: int = Field(default=RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT, gt=0) budget_bytes: int | None = Field(default=None, gt=0) - page_size: int = Field(default=64, gt=0) - rank: int = Field(default=1, gt=0) - bits: int = Field(default=3, ge=2, le=4) - group_size: int = Field(default=128, gt=0) - hot_pages: int = Field(default=8, gt=0) - rotation_seed: int = Field(default=17) - residual_topk: int = Field(default=1, ge=0) - prefix_keep_vectors: int = Field(default=32, ge=0) - suffix_keep_vectors: int = Field(default=64, ge=0) - low_rank_error_threshold: float = Field(default=0.03, ge=0.0) - ref_round_decimals: int = Field(default=3, ge=0) - enable_page_ref: bool = True - page_ref_rel_rms_threshold: float = Field(default=0.005, ge=0.0) - enable_int8_fallback: bool = True - try_int8_for_protected: bool = True - int8_rel_rms_threshold: float = Field(default=0.01, ge=0.0) - int8_max_abs_threshold: float = Field(default=0.05, ge=0.0) - prefer_native_fwht: bool = True - allow_vector_for_protected: bool = False + page_size: int = Field(default=CONTEXT_PAGE_SIZE_DEFAULT, gt=0) + rank: int = Field(default=CONTEXT_RANK_DEFAULT, gt=0) + bits: int = Field(default=VECTOR_BITS_DEFAULT, ge=2, le=4) + group_size: int = Field(default=VECTOR_GROUP_SIZE_DEFAULT, gt=0) + hot_pages: int = Field(default=RESIDENT_HOT_PAGES_DEFAULT, gt=0) + rotation_seed: int = Field(default=VECTOR_ROTATION_SEED_DEFAULT) + residual_topk: int = Field(default=VECTOR_RESIDUAL_TOPK_DEFAULT, ge=0) + prefix_keep_vectors: int = Field(default=CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, ge=0) + suffix_keep_vectors: int = Field(default=CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, ge=0) + low_rank_error_threshold: float = Field(default=CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, ge=0.0) + ref_round_decimals: int = Field(default=CONTEXT_REF_ROUND_DECIMALS_DEFAULT, ge=0) + enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT + page_ref_rel_rms_threshold: float = Field(default=CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, ge=0.0) + enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT + try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT + int8_rel_rms_threshold: float = Field(default=CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, ge=0.0) + int8_max_abs_threshold: float = Field(default=CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, ge=0.0) + prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT + allow_vector_for_protected: bool = RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT class ResidentPlanResponse(BaseModel): diff --git a/hyperquant/cli.py b/hyperquant/cli.py index 94682da..382be0c 100644 --- a/hyperquant/cli.py +++ b/hyperquant/cli.py @@ -48,6 +48,25 @@ ContextCodecConfig, ContextCodec, ) +from .defaults import ( + CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, + CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, + CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_PAGE_SIZE_DEFAULT, + CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_RANK_DEFAULT, + CONTEXT_REF_ROUND_DECIMALS_DEFAULT, + CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, + RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT, + RESIDENT_CONCURRENT_SESSIONS_DEFAULT, + RESIDENT_HOT_PAGES_DEFAULT, + RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT, + VECTOR_BITS_DEFAULT, + VECTOR_GROUP_SIZE_DEFAULT, + VECTOR_RESIDUAL_TOPK_DEFAULT, + VECTOR_ROTATION_SEED_DEFAULT, +) from .resident_tier import ( ResidentTierConfig, ResidentPlanner, @@ -206,7 +225,7 @@ def cmd_context_compress_file(args: argparse.Namespace) -> int: def cmd_context_decompress_file(args: argparse.Namespace) -> int: envelope = ContextEnvelope.from_bytes(Path(args.input).read_bytes()) - compressor = ContextCodec(_build_context_config(args)) + compressor = ContextCodec(ContextCodecConfig(page_size=envelope.page_size, rank=envelope.rank)) restored = compressor.decompress(envelope) np.save(args.output, restored, allow_pickle=False) print(f"saved reconstructed array to {args.output}") @@ -474,18 +493,18 @@ def cmd_serve(args: argparse.Namespace) -> int: def _add_context_args(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--page-size", type=int, default=64) - parser.add_argument("--rank", type=int, default=1) - parser.add_argument("--prefix-keep-vectors", type=int, default=32) - parser.add_argument("--suffix-keep-vectors", type=int, default=64) - parser.add_argument("--low-rank-error-threshold", type=float, default=0.03) - parser.add_argument("--ref-round-decimals", type=int, default=3) - parser.add_argument("--page-ref-rel-rms-threshold", type=float, default=0.005) + parser.add_argument("--page-size", type=int, default=CONTEXT_PAGE_SIZE_DEFAULT) + parser.add_argument("--rank", type=int, default=CONTEXT_RANK_DEFAULT) + parser.add_argument("--prefix-keep-vectors", type=int, default=CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT) + parser.add_argument("--suffix-keep-vectors", type=int, default=CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT) + parser.add_argument("--low-rank-error-threshold", type=float, default=CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT) + parser.add_argument("--ref-round-decimals", type=int, default=CONTEXT_REF_ROUND_DECIMALS_DEFAULT) + parser.add_argument("--page-ref-rel-rms-threshold", type=float, default=CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT) parser.add_argument("--disable-page-ref", action="store_true") parser.add_argument("--disable-int8-fallback", action="store_true") parser.add_argument("--disable-int8-for-protected", action="store_true") - parser.add_argument("--int8-rel-rms-threshold", type=float, default=0.01) - parser.add_argument("--int8-max-abs-threshold", type=float, default=0.05) + parser.add_argument("--int8-rel-rms-threshold", type=float, default=CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT) + parser.add_argument("--int8-max-abs-threshold", type=float, default=CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT) @@ -506,10 +525,10 @@ def _add_timing_args(parser: argparse.ArgumentParser) -> None: def _add_vector_args(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--bits", type=int, default=3) - parser.add_argument("--group-size", type=int, default=128) - parser.add_argument("--rotation-seed", type=int, default=17) - parser.add_argument("--residual-topk", type=int, default=1, help="Number of rotated coefficients per group stored exactly as a residual rescue side-channel.") + parser.add_argument("--bits", type=int, default=VECTOR_BITS_DEFAULT) + parser.add_argument("--group-size", type=int, default=VECTOR_GROUP_SIZE_DEFAULT) + parser.add_argument("--rotation-seed", type=int, default=VECTOR_ROTATION_SEED_DEFAULT) + parser.add_argument("--residual-topk", type=int, default=VECTOR_RESIDUAL_TOPK_DEFAULT, help="Number of rotated coefficients per group stored exactly as a residual rescue side-channel.") parser.add_argument("--disable-native-fwht", action="store_true") @@ -517,7 +536,7 @@ def _add_vector_args(parser: argparse.ArgumentParser) -> None: def _add_resident_tier_args(parser: argparse.ArgumentParser) -> None: _add_context_args(parser) _add_vector_args(parser) - parser.add_argument("--hot-pages", type=int, default=8) + parser.add_argument("--hot-pages", type=int, default=RESIDENT_HOT_PAGES_DEFAULT) parser.add_argument("--allow-vector-for-protected", action="store_true") @@ -611,7 +630,6 @@ def build_parser() -> argparse.ArgumentParser: context_decompress = sub.add_parser("context-decompress-file") context_decompress.add_argument("--input", required=True) context_decompress.add_argument("--output", required=True) - _add_context_args(context_decompress) context_decompress.set_defaults(func=cmd_context_decompress_file) context_benchmark = sub.add_parser("context-benchmark") @@ -643,9 +661,9 @@ def build_parser() -> argparse.ArgumentParser: resident_plan = sub.add_parser("resident-plan") resident_plan.add_argument("--input", required=True) resident_plan.add_argument("--output") - resident_plan.add_argument("--concurrent-sessions", type=int, default=8) - resident_plan.add_argument("--active-window-tokens", type=int, default=256) - resident_plan.add_argument("--runtime-value-bytes", type=int, default=2) + resident_plan.add_argument("--concurrent-sessions", type=int, default=RESIDENT_CONCURRENT_SESSIONS_DEFAULT) + resident_plan.add_argument("--active-window-tokens", type=int, default=RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT) + resident_plan.add_argument("--runtime-value-bytes", type=int, default=RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT) resident_plan.add_argument("--budget-bytes", type=int) _add_resident_tier_args(resident_plan) resident_plan.set_defaults(func=cmd_resident_plan) @@ -673,9 +691,9 @@ def build_parser() -> argparse.ArgumentParser: resident_benchmark.add_argument("--structured-tokens", type=int, default=16384) resident_benchmark.add_argument("--mixed-tokens", type=int, default=32768) resident_benchmark.add_argument("--dim", type=int, default=128) - resident_benchmark.add_argument("--concurrent-sessions", type=int, default=8) - resident_benchmark.add_argument("--active-window-tokens", type=int, default=256) - resident_benchmark.add_argument("--runtime-value-bytes", type=int, default=2) + resident_benchmark.add_argument("--concurrent-sessions", type=int, default=RESIDENT_CONCURRENT_SESSIONS_DEFAULT) + resident_benchmark.add_argument("--active-window-tokens", type=int, default=RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT) + resident_benchmark.add_argument("--runtime-value-bytes", type=int, default=RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT) resident_benchmark.add_argument("--slice-iterations", type=int, default=5) resident_benchmark.add_argument("--seed", type=int, default=20260329) resident_benchmark.add_argument("--json-output") diff --git a/hyperquant/context_codec.py b/hyperquant/context_codec.py index eeb25ab..ca31097 100644 --- a/hyperquant/context_codec.py +++ b/hyperquant/context_codec.py @@ -24,6 +24,20 @@ import numpy as np from .contour import ContourAnalysis, ContourThresholds, ProductContour, analyze_context_contour +from .defaults import ( + CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT, + CONTEXT_ENABLE_PAGE_REF_DEFAULT, + CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, + CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, + CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_PAGE_SIZE_DEFAULT, + CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_RANK_DEFAULT, + CONTEXT_REF_ROUND_DECIMALS_DEFAULT, + CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT, +) from .guarantee import ( ContourViolation, GuaranteeMode, @@ -44,18 +58,18 @@ class ContextPageMode(IntEnum): @dataclass(frozen=True) class ContextCodecConfig: - page_size: int = 64 - rank: int = 1 - prefix_keep_vectors: int = 32 - suffix_keep_vectors: int = 64 - low_rank_error_threshold: float = 0.03 - ref_round_decimals: int = 3 - enable_page_ref: bool = True - page_ref_rel_rms_threshold: float = 0.005 - enable_int8_fallback: bool = True - try_int8_for_protected: bool = True - int8_rel_rms_threshold: float = 0.01 - int8_max_abs_threshold: float = 0.05 + page_size: int = CONTEXT_PAGE_SIZE_DEFAULT + rank: int = CONTEXT_RANK_DEFAULT + prefix_keep_vectors: int = CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT + suffix_keep_vectors: int = CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT + low_rank_error_threshold: float = CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT + ref_round_decimals: int = CONTEXT_REF_ROUND_DECIMALS_DEFAULT + enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT + page_ref_rel_rms_threshold: float = CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT + enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT + try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT + int8_rel_rms_threshold: float = CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT + int8_max_abs_threshold: float = CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT def validate(self) -> None: if self.page_size <= 0: diff --git a/hyperquant/defaults.py b/hyperquant/defaults.py new file mode 100644 index 0000000..cd9e34b --- /dev/null +++ b/hyperquant/defaults.py @@ -0,0 +1,40 @@ +# Copyright 2026 Сацук Артём Венедиктович (Satsuk Artem) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +VECTOR_BITS_DEFAULT = 3 +VECTOR_GROUP_SIZE_DEFAULT = 128 +VECTOR_ROTATION_SEED_DEFAULT = 17 +VECTOR_RESIDUAL_TOPK_DEFAULT = 1 +VECTOR_PREFER_NATIVE_FWHT_DEFAULT = True + +CONTEXT_PAGE_SIZE_DEFAULT = 64 +CONTEXT_RANK_DEFAULT = 1 +CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT = 32 +CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT = 64 +CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT = 0.03 +CONTEXT_REF_ROUND_DECIMALS_DEFAULT = 3 +CONTEXT_ENABLE_PAGE_REF_DEFAULT = True +CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT = 0.005 +CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT = True +CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT = True +CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT = 0.01 +CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT = 0.05 + +RESIDENT_HOT_PAGES_DEFAULT = 8 +RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT = False +RESIDENT_CONCURRENT_SESSIONS_DEFAULT = 8 +RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT = 256 +RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT = 2 diff --git a/hyperquant/resident_tier.py b/hyperquant/resident_tier.py index 66583a5..a9a9021 100644 --- a/hyperquant/resident_tier.py +++ b/hyperquant/resident_tier.py @@ -26,7 +26,28 @@ import numpy as np -from .guarantee import GuaranteeMode +from .defaults import ( + CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT, + CONTEXT_ENABLE_PAGE_REF_DEFAULT, + CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, + CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, + CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, + CONTEXT_PAGE_SIZE_DEFAULT, + CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_RANK_DEFAULT, + CONTEXT_REF_ROUND_DECIMALS_DEFAULT, + CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, + CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT, + RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT, + RESIDENT_HOT_PAGES_DEFAULT, + VECTOR_BITS_DEFAULT, + VECTOR_GROUP_SIZE_DEFAULT, + VECTOR_PREFER_NATIVE_FWHT_DEFAULT, + VECTOR_RESIDUAL_TOPK_DEFAULT, + VECTOR_ROTATION_SEED_DEFAULT, +) +from .guarantee import ContourViolation, GuaranteeMode, GuaranteeViolation from .context_codec import ContextCodecConfig, ContextCodec from .vector_codec import VectorCodec from .utils import sha256_hex @@ -43,25 +64,25 @@ class ResidentPageMode(StrEnum): @dataclass(frozen=True) class ResidentTierConfig: - page_size: int = 64 - rank: int = 1 - bits: int = 3 - group_size: int = 128 - rotation_seed: int = 17 - residual_topk: int = 1 - hot_pages: int = 8 - prefix_keep_vectors: int = 32 - suffix_keep_vectors: int = 64 - low_rank_error_threshold: float = 0.03 - ref_round_decimals: int = 3 - enable_page_ref: bool = True - page_ref_rel_rms_threshold: float = 0.005 - enable_int8_fallback: bool = True - try_int8_for_protected: bool = True - int8_rel_rms_threshold: float = 0.01 - int8_max_abs_threshold: float = 0.05 - prefer_native_fwht: bool = True - allow_vector_for_protected: bool = False + page_size: int = CONTEXT_PAGE_SIZE_DEFAULT + rank: int = CONTEXT_RANK_DEFAULT + bits: int = VECTOR_BITS_DEFAULT + group_size: int = VECTOR_GROUP_SIZE_DEFAULT + rotation_seed: int = VECTOR_ROTATION_SEED_DEFAULT + residual_topk: int = VECTOR_RESIDUAL_TOPK_DEFAULT + hot_pages: int = RESIDENT_HOT_PAGES_DEFAULT + prefix_keep_vectors: int = CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT + suffix_keep_vectors: int = CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT + low_rank_error_threshold: float = CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT + ref_round_decimals: int = CONTEXT_REF_ROUND_DECIMALS_DEFAULT + enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT + page_ref_rel_rms_threshold: float = CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT + enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT + try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT + int8_rel_rms_threshold: float = CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT + int8_max_abs_threshold: float = CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT + prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT + allow_vector_for_protected: bool = RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT def validate(self) -> None: if self.page_size <= 0: @@ -882,7 +903,7 @@ def plan( context_error = None try: _, context_stats = self._context.compress(validated, guarantee_mode=GuaranteeMode.ALLOW_BEST_EFFORT) - except Exception as exc: + except (ValueError, ContourViolation, GuaranteeViolation) as exc: context_error = str(exc) candidates: dict[str, dict[str, object]] = { diff --git a/hyperquant/vector_codec.py b/hyperquant/vector_codec.py index 8bdad69..10cfbbe 100644 --- a/hyperquant/vector_codec.py +++ b/hyperquant/vector_codec.py @@ -21,6 +21,13 @@ import numpy as np from .compat import StrEnum +from .defaults import ( + VECTOR_BITS_DEFAULT, + VECTOR_GROUP_SIZE_DEFAULT, + VECTOR_PREFER_NATIVE_FWHT_DEFAULT, + VECTOR_RESIDUAL_TOPK_DEFAULT, + VECTOR_ROTATION_SEED_DEFAULT, +) from .native_core import fwht_rows, native_fwht_status from .utils import EPS, bytes_from_b64, bytes_to_b64, sha256_hex from .validation import ShapeLimits, validate_float_dtype, validate_numeric_finite_array, validate_shape @@ -49,14 +56,14 @@ class RotationKind(StrEnum): @dataclass(frozen=True) class RotatedScalarConfig: - bits: int = 3 - group_size: int = 128 + bits: int = VECTOR_BITS_DEFAULT + group_size: int = VECTOR_GROUP_SIZE_DEFAULT rotation_kind: RotationKind = RotationKind.STRUCTURED_FWHT - rotation_seed: int = 17 + rotation_seed: int = VECTOR_ROTATION_SEED_DEFAULT normalize: bool = True - prefer_native_fwht: bool = True + prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT profile_name: str = "vector_codec" - residual_topk: int = 1 + residual_topk: int = VECTOR_RESIDUAL_TOPK_DEFAULT def validate(self) -> None: if self.bits not in _NORMAL_LLOYD_MAX: diff --git a/tests/array_builders.py b/tests/array_builders.py new file mode 100644 index 0000000..38d24a0 --- /dev/null +++ b/tests/array_builders.py @@ -0,0 +1,56 @@ +# Copyright 2026 Сацук Артём Венедиктович (Satsuk Artem) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + + +def build_context_like_array(n_tokens: int = 4096, dim: int = 128, page_size: int = 64) -> np.ndarray: + rng = np.random.default_rng(123) + n_pages = n_tokens // page_size + pages: list[np.ndarray] = [] + templates: list[np.ndarray] = [] + + for _ in range(4): + topic = rng.standard_normal((1, dim)).astype(np.float32) + coeff = rng.standard_normal((page_size, 1)).astype(np.float32) * 0.7 + basis = rng.standard_normal((1, dim)).astype(np.float32) + page = topic + coeff @ basis + 0.002 * rng.standard_normal((page_size, dim)).astype(np.float32) + templates.append(page.astype(np.float32)) + + for page_idx in range(n_pages): + if page_idx < 4: + pages.append(templates[page_idx]) + elif page_idx < 16: + pages.append(templates[page_idx % 4].copy()) + elif page_idx >= n_pages - 1: + topic = rng.standard_normal((1, dim)).astype(np.float32) + coeff = rng.standard_normal((page_size, 3)).astype(np.float32) + basis = rng.standard_normal((3, dim)).astype(np.float32) + page = topic + coeff @ basis + 0.03 * rng.standard_normal((page_size, dim)).astype(np.float32) + pages.append(page.astype(np.float32)) + else: + topic = rng.standard_normal((1, dim)).astype(np.float32) + coeff = rng.standard_normal((page_size, 1)).astype(np.float32) + basis = rng.standard_normal((1, dim)).astype(np.float32) + page = topic + coeff @ basis + 0.008 * rng.standard_normal((page_size, dim)).astype(np.float32) + pages.append(page.astype(np.float32)) + + return np.concatenate(pages, axis=0).astype(np.float32) + + +def build_random_array(n_tokens: int = 4096, dim: int = 128) -> np.ndarray: + rng = np.random.default_rng(999) + return rng.standard_normal((n_tokens, dim)).astype(np.float32) diff --git a/tests/test_api.py b/tests/test_api.py index ad517a1..579d001 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -19,13 +19,16 @@ import httpx import numpy as np +import pytest +from hyperquant.api import app as api_app from hyperquant.api.app import create_app from hyperquant.bundle import CodebookBundle from hyperquant.codebook import MiniBatchKMeansTrainer from hyperquant.config import CodebookConfig from hyperquant.context_codec import ContextEnvelope from hyperquant.utils import bytes_to_b64, ndarray_from_b64, ndarray_to_b64 +from tests.array_builders import build_context_like_array, build_random_array def build_demo_array() -> np.ndarray: @@ -33,45 +36,6 @@ def build_demo_array() -> np.ndarray: return rng.standard_normal((64, 128)).astype(np.float32) -def build_context_like_array(n_tokens: int = 4096, dim: int = 128, page_size: int = 64) -> np.ndarray: - rng = np.random.default_rng(123) - n_pages = n_tokens // page_size - pages: list[np.ndarray] = [] - templates: list[np.ndarray] = [] - - for _ in range(4): - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 1)).astype(np.float32) * 0.7 - basis = rng.standard_normal((1, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.002 * rng.standard_normal((page_size, dim)).astype(np.float32) - templates.append(page.astype(np.float32)) - - for page_idx in range(n_pages): - if page_idx < 4: - pages.append(templates[page_idx]) - elif page_idx < 16: - pages.append(templates[page_idx % 4].copy()) - elif page_idx >= n_pages - 1: - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 3)).astype(np.float32) - basis = rng.standard_normal((3, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.03 * rng.standard_normal((page_size, dim)).astype(np.float32) - pages.append(page.astype(np.float32)) - else: - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 1)).astype(np.float32) - basis = rng.standard_normal((1, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.008 * rng.standard_normal((page_size, dim)).astype(np.float32) - pages.append(page.astype(np.float32)) - - return np.concatenate(pages, axis=0).astype(np.float32) - - -def build_random_array(n_tokens: int = 4096, dim: int = 128) -> np.ndarray: - rng = np.random.default_rng(999) - return rng.standard_normal((n_tokens, dim)).astype(np.float32) - - async def _scenario_client(app, fn): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: @@ -297,3 +261,54 @@ async def scenario(client: httpx.AsyncClient) -> None: assert "hyperquant_last_projected_resident_ratio" in metrics.text asyncio.run(_scenario_client(app, scenario)) + + +def test_api_returns_500_for_internal_worker_error(tmp_path, monkeypatch) -> None: + array = build_demo_array() + trainer = MiniBatchKMeansTrainer(CodebookConfig(chunk_size=32, codebook_size=32, sample_size=1024, training_iterations=4)) + bundle = trainer.train(array) + bundle_path = tmp_path / "bundle.npz" + bundle.save(bundle_path) + + def raising_ndarray_from_b64(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(api_app, "ndarray_from_b64", raising_ndarray_from_b64) + app = create_app(bundle_path, max_request_bytes=4 * 1024 * 1024) + + async def scenario(client: httpx.AsyncClient) -> None: + response = await client.post( + "/v1/codebook/compress", + json={"array_b64": ndarray_to_b64(array)}, + ) + assert response.status_code == 500 + assert response.json()["detail"] == "internal server error" + metrics = await client.get("/metrics") + assert 'endpoint="compress",reason="internal_error"' in metrics.text + + asyncio.run(_scenario_client(app, scenario)) + + +@pytest.mark.parametrize("headers", [{}, {"content-length": "invalid"}, {"content-length": "10"}]) +def test_api_rejects_oversized_body_without_reliable_content_length(tmp_path, headers) -> None: + array = build_demo_array() + trainer = MiniBatchKMeansTrainer(CodebookConfig(chunk_size=32, codebook_size=32, sample_size=1024, training_iterations=4)) + bundle = trainer.train(array) + bundle_path = tmp_path / "bundle.npz" + bundle.save(bundle_path) + app = create_app(bundle_path, max_request_bytes=1024) + + oversized_payload = b'{"array_b64":"' + (b"A" * (2 * 1024 * 1024)) + b'"}' + + async def scenario(client: httpx.AsyncClient) -> None: + async def stream(): + midpoint = len(oversized_payload) // 2 + yield oversized_payload[:midpoint] + yield oversized_payload[midpoint:] + + request_headers = {"content-type": "application/json", **headers} + response = await client.post("/v1/codebook/compress", content=stream(), headers=request_headers) + assert response.status_code == 413 + assert "max_http_body_bytes" in response.json()["detail"] + + asyncio.run(_scenario_client(app, scenario)) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..91a4de8 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,33 @@ +# Copyright 2026 Сацук Артём Венедиктович (Satsuk Artem) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from hyperquant.cli import build_parser + + +def _subcommand_parser(name: str): + parser = build_parser() + subparsers_action = next(action for action in parser._actions if action.dest == "command") # noqa: SLF001 + return subparsers_action.choices[name] + + +def test_context_decompress_file_help_has_only_decode_inputs() -> None: + parser = _subcommand_parser("context-decompress-file") + help_text = parser.format_help() + assert "--input" in help_text + assert "--output" in help_text + assert "--page-size" not in help_text + assert "--rank" not in help_text + assert "--low-rank-error-threshold" not in help_text diff --git a/tests/test_context_codec.py b/tests/test_context_codec.py index 8dffaa9..1c1312f 100644 --- a/tests/test_context_codec.py +++ b/tests/test_context_codec.py @@ -26,49 +26,7 @@ ContextCodec, ) from hyperquant.utils import bytes_to_b64 - - -def build_context_like_array( - n_tokens: int = 4096, - dim: int = 128, - page_size: int = 64, -) -> np.ndarray: - rng = np.random.default_rng(123) - n_pages = n_tokens // page_size - pages: list[np.ndarray] = [] - templates: list[np.ndarray] = [] - - for _ in range(4): - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 1)).astype(np.float32) * 0.7 - basis = rng.standard_normal((1, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.002 * rng.standard_normal((page_size, dim)).astype(np.float32) - templates.append(page.astype(np.float32)) - - for page_idx in range(n_pages): - if page_idx < 4: - pages.append(templates[page_idx]) - elif page_idx < 16: - pages.append(templates[page_idx % 4].copy()) - elif page_idx >= n_pages - 1: - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 3)).astype(np.float32) - basis = rng.standard_normal((3, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.03 * rng.standard_normal((page_size, dim)).astype(np.float32) - pages.append(page.astype(np.float32)) - else: - topic = rng.standard_normal((1, dim)).astype(np.float32) - coeff = rng.standard_normal((page_size, 1)).astype(np.float32) - basis = rng.standard_normal((1, dim)).astype(np.float32) - page = topic + coeff @ basis + 0.008 * rng.standard_normal((page_size, dim)).astype(np.float32) - pages.append(page.astype(np.float32)) - - return np.concatenate(pages, axis=0).astype(np.float32) - - -def build_random_array(n_tokens: int = 4096, dim: int = 128) -> np.ndarray: - rng = np.random.default_rng(999) - return rng.standard_normal((n_tokens, dim)).astype(np.float32) +from tests.array_builders import build_context_like_array, build_random_array def test_context_roundtrip_and_ratio() -> None: diff --git a/tests/test_resident_tier.py b/tests/test_resident_tier.py index ad9cb50..dd64450 100644 --- a/tests/test_resident_tier.py +++ b/tests/test_resident_tier.py @@ -19,6 +19,8 @@ from hyperquant.live_data import generate_mixed_long_context, generate_online_vector_stream from hyperquant.resident_tier import ResidentTierConfig, ResidentPlanner, ResidentPageMode, ResidentTierStore +from hyperquant.context_codec import ContextCodec +from hyperquant.guarantee import ContourViolation def test_tiered_store_build_open_and_slice(tmp_path) -> None: @@ -91,3 +93,43 @@ def test_tiered_store_detects_payload_tampering(tmp_path) -> None: reopened = ResidentTierStore.open(tmp_path / "store") with pytest.raises(ValueError, match="sha256 mismatch"): reopened.get_page(page.page_index) + + +def test_resident_planner_keeps_expected_context_qualification_errors(monkeypatch) -> None: + array = generate_online_vector_stream(n_vectors=1024, dim=64, seed=20260329) + planner = ResidentPlanner(ResidentTierConfig(page_size=32, group_size=64, hot_pages=4, residual_topk=1)) + + def raise_value_error(self, array, **kwargs): # noqa: ARG001 + raise ValueError("not context-like") + + monkeypatch.setattr(ContextCodec, "compress", raise_value_error) + plan = planner.plan(array, concurrent_sessions=2, active_window_tokens=64, runtime_value_bytes=2) + candidate = plan.candidates["context_codec_full_envelope"] + assert candidate["resident_bytes_per_session"] is None + assert candidate["error"] == "not context-like" + + +def test_resident_planner_propagates_unexpected_context_errors(monkeypatch) -> None: + array = generate_online_vector_stream(n_vectors=1024, dim=64, seed=20260329) + planner = ResidentPlanner(ResidentTierConfig(page_size=32, group_size=64, hot_pages=4, residual_topk=1)) + + def raise_runtime_error(self, array, **kwargs): # noqa: ARG001 + raise RuntimeError("unexpected") + + monkeypatch.setattr(ContextCodec, "compress", raise_runtime_error) + with pytest.raises(RuntimeError, match="unexpected"): + planner.plan(array, concurrent_sessions=2, active_window_tokens=64, runtime_value_bytes=2) + + +def test_resident_planner_keeps_contour_qualification_errors(monkeypatch) -> None: + array = generate_online_vector_stream(n_vectors=1024, dim=64, seed=20260329) + planner = ResidentPlanner(ResidentTierConfig(page_size=32, group_size=64, hot_pages=4, residual_topk=1)) + + def raise_contour_violation(self, array, **kwargs): # noqa: ARG001 + raise ContourViolation(["contour reject"]) + + monkeypatch.setattr(ContextCodec, "compress", raise_contour_violation) + plan = planner.plan(array, concurrent_sessions=2, active_window_tokens=64, runtime_value_bytes=2) + candidate = plan.candidates["context_codec_full_envelope"] + assert candidate["resident_bytes_per_session"] is None + assert "contour reject" in candidate["error"]