diff --git a/.claude/rules/rust.md b/.claude/rules/rust.md index 3d940a23..4584b330 100644 --- a/.claude/rules/rust.md +++ b/.claude/rules/rust.md @@ -21,3 +21,16 @@ paths: - Never use Result<> as a function argument. - Never forward Result in enums if you can instead create a targeted error enum. It is always better to signal the specific issue, so it can be handled downstream. - Always destructure structs in arguments if possible. + +# Code Style + +Imports/uses must not be mixed with other kinds of rust syntax. + +Each file needs to follow this order: +1. `pub mod`/`mod` exports +2. vendor crate `use` +2. project crate `use` +3. local crate `use` +4. private function helpers +5. private struct helpers +6. single public export diff --git a/.claude/rules/subproject-llama-cpp-log-decoder.md b/.claude/rules/subproject-llama-cpp-log-decoder.md new file mode 100644 index 00000000..32ca56ad --- /dev/null +++ b/.claude/rules/subproject-llama-cpp-log-decoder.md @@ -0,0 +1,11 @@ +--- +paths: + - "llama-cpp-log-decoder/**" +--- + +# `llama-cpp-log-decoder` Standards + +- The logging subsystem MUST NEVER panic, crash, or otherwise interrupt the program. +- Logs report issues; they must not cause them. +- No .unwrap(), .expect(), panic!(), or panic-prone indexing. +- No panic-prone slicing. diff --git a/.claude/skills/run-all-tests/SKILL.md b/.claude/skills/run-all-tests/SKILL.md new file mode 100644 index 00000000..cbf1584e --- /dev/null +++ b/.claude/skills/run-all-tests/SKILL.md @@ -0,0 +1,70 @@ +--- +name: run-all-tests +description: Runs every test suite in the workspace on the fastest available device. Use when the user asks to run the tests, run all the tests, run the full test suite, or check that everything still passes. +--- + +# Running all tests + +Run every test suite in the workspace, picking the fastest compiled device backend for the host. + +## Step 1: detect the device + +Run this once at the start and echo the chosen device: + +```bash +if [[ "$OSTYPE" == "darwin"* ]]; then + DEVICE=metal +elif command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi >/dev/null 2>&1; then + DEVICE=cuda +else + DEVICE=cpu +fi +echo "Device: $DEVICE" +``` + +`$DEVICE` selects the backend feature for every suite in Step 2, including `test.unit`. Passing the same device through every target keeps the cmake hash stable, so llama.cpp is compiled once and reused across all suites. + +## Step 2: run the suites + +Sequentially, from the workspace root. + +Copy this checklist and tick each item as the suite completes: + +``` +Test progress: +- [ ] make test.unit +- [ ] make test.qwen3.5_0.8B +- [ ] make test.qwen3.6_35b_a3b +- [ ] make test.glm4_7_flash +- [ ] make test.deepseek_r1_distill_llama_8b +``` + +Translate `$DEVICE` into the value the Makefile expects. `TEST_DEVICE` holds **only** the backend name (`cuda` / `metal` / `vulkan` / `rocm`), or empty for CPU since there is no `cpu` feature: + +```bash +[ "$DEVICE" = "cpu" ] && FEAT= || FEAT="$DEVICE" +``` + +Then run exactly: + +```bash +make test.unit TEST_DEVICE="$FEAT" +make test.qwen3.5_0.8B TEST_DEVICE="$FEAT" +make test.qwen3.6_35b_a3b TEST_DEVICE="$FEAT" +make test.glm4_7_flash TEST_DEVICE="$FEAT" +make test.deepseek_r1_distill_llama_8b TEST_DEVICE="$FEAT" +``` + +The Makefile's `$(if $(TEST_DEVICE),--features $(TEST_DEVICE),)` already skips the `--features` flag when `$FEAT` is empty, so the CPU path needs no further special-casing. + +Do not run `make test.llms` or `make test`. Those bundle every LLM suite into one cargo invocation, which loses per-suite failure attribution and breaks the checklist above. + +## Step 3: rules during the run + +- **Serialize GPU suites.** When `$DEVICE` is `cuda` or `metal`, run test suites sequentially to avoid device contention. +- **Per-test 30 s budget.** Flag any individual test that exceeds 30 s wall-clock. That is a real bug — production or test — not flakiness. + +## Step 4: report + +After all suites finish, sum up the results in an actionable report. + diff --git a/Cargo.lock b/Cargo.lock index d621b601..d5d0c96f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,12 +983,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - [[package]] name = "leb128fmt" version = "0.1.0" @@ -1046,15 +1040,14 @@ dependencies = [ "enumflags2", "llama-cpp-bindings-sys", "llama-cpp-bindings-types", + "llama-cpp-log-decoder", "llguidance", + "log", "nom 8.0.0", "serde_json", "serial_test", "thiserror", "toktrie", - "tracing", - "tracing-core", - "tracing-subscriber", ] [[package]] @@ -1088,8 +1081,6 @@ dependencies = [ "llama-cpp-bindings-sys", "serde_json", "serial_test", - "tracing", - "tracing-subscriber", ] [[package]] @@ -1101,6 +1092,10 @@ dependencies = [ "thiserror", ] +[[package]] +name = "llama-cpp-log-decoder" +version = "0.6.0" + [[package]] name = "llguidance" version = "1.7.0" @@ -1208,15 +1203,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "nu-ansi-term" -version = "0.50.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "num-conv" version = "0.2.1" @@ -1814,15 +1800,6 @@ dependencies = [ "syn", ] -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - [[package]] name = "shlex" version = "1.3.0" @@ -1986,15 +1963,6 @@ dependencies = [ "syn", ] -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - [[package]] name = "time" version = "0.3.47" @@ -2160,21 +2128,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", - "tracing-attributes", "tracing-core", ] -[[package]] -name = "tracing-attributes" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tracing-core" version = "0.1.36" @@ -2182,45 +2138,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-serde" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1" -dependencies = [ - "serde", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" -dependencies = [ - "nu-ansi-term", - "serde", - "serde_json", - "sharded-slab", - "smallvec", - "thread_local", - "tracing-core", - "tracing-log", - "tracing-serde", ] [[package]] @@ -2319,12 +2236,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 477b4224..05cc20f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "llama-cpp-bindings-types", "llama-cpp-bindings", "llama-cpp-bindings-tests", + "llama-cpp-log-decoder", ] [workspace.package] @@ -28,14 +29,13 @@ llama-cpp-bindings = { path = "llama-cpp-bindings", version = "=0.6.0" } llama-cpp-bindings-build = { path = "llama-cpp-bindings-build", version = "=0.6.0" } llama-cpp-bindings-sys = { path = "llama-cpp-bindings-sys", version = "=0.6.0" } llama-cpp-bindings-types = { path = "llama-cpp-bindings-types", version = "=0.6.0" } +llama-cpp-log-decoder = { path = "llama-cpp-log-decoder", version = "=0.6.0" } llguidance = "=1.7.0" +log = "=0.4.29" nom = "=8.0.0" serde = { version = "=1.0.228", features = ["derive"] } serde_json = "=1.0.149" serial_test = "=3.4.0" thiserror = "=2.0.18" toktrie = "=1.7.0" -tracing = "=0.1.44" -tracing-core = "=0.1.36" -tracing-subscriber = { version = "=0.3.23", features = ["json"] } walkdir = "=2.5.0" diff --git a/Makefile b/Makefile index bae3fbf6..35a09b34 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,12 @@ -FEATURES = sampler -TEST_FEATURES = +TEST_DEVICE ?= QWEN_CAPABLE_FEATURES = multimodal_capable,mrope_model -CARGO_TEST_LLM_FLAGS = --no-fail-fast -p llama-cpp-bindings-tests $(if $(TEST_FEATURES),--features $(TEST_FEATURES),) -- --test-threads=1 -CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE = --no-fail-fast -p llama-cpp-bindings-tests $(if $(TEST_FEATURES),--features $(TEST_FEATURES),) --features $(QWEN_CAPABLE_FEATURES) -- --test-threads=1 + +DEVICE_FEATURE = $(if $(TEST_DEVICE),--features $(TEST_DEVICE),) +LLM_BASE_FEATURE_FLAGS = $(DEVICE_FEATURE) +LLM_QWEN_CAPABLE_FEATURE_FLAGS = $(DEVICE_FEATURE) --features $(QWEN_CAPABLE_FEATURES) + +CARGO_TEST_LLM_FLAGS = --no-fail-fast -p llama-cpp-bindings-tests $(LLM_BASE_FEATURE_FLAGS) -- --test-threads=1 +CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE = --no-fail-fast -p llama-cpp-bindings-tests $(LLM_QWEN_CAPABLE_FEATURE_FLAGS) -- --test-threads=1 QWEN3_5_0_8B_ENV = \ LLAMA_TEST_HF_REPO=unsloth/Qwen3.5-0.8B-GGUF \ @@ -38,26 +42,41 @@ DEEPSEEK_R1_DISTILL_LLAMA_8B_ENV = \ LLAMA_TEST_HF_ENCODER_REPO=Xiaojian9992024/t5-small-GGUF \ LLAMA_TEST_HF_ENCODER_MODEL=t5-small.bf16.gguf -.PHONY: test.unit -test.unit: clippy - cargo test -p llama-cpp-bindings --features $(FEATURES) +.PHONY: clean.cmake +clean.cmake: + rm -rf target/llama-cpp-cmake-build + +.PHONY: clippy +clippy: clippy.core clippy.tests.base clippy.tests.qwen_capable + +.PHONY: clippy.core +clippy.core: + cargo clippy --all-targets -p llama-cpp-log-decoder -- -D warnings + cargo clippy --all-targets -p llama-cpp-bindings $(DEVICE_FEATURE) -- -D warnings + +.PHONY: clippy.tests.base +clippy.tests.base: + cargo clippy --all-targets -p llama-cpp-bindings-tests $(LLM_BASE_FEATURE_FLAGS) -- -D warnings + +.PHONY: clippy.tests.qwen_capable +clippy.tests.qwen_capable: + cargo clippy --all-targets -p llama-cpp-bindings-tests $(LLM_QWEN_CAPABLE_FEATURE_FLAGS) -- -D warnings + +.PHONY: fmt +fmt: + cargo fmt --all --check + +.PHONY: test +test: test.unit test.llms .PHONY: test.deepseek_r1_distill_llama_8b -test.deepseek_r1_distill_llama_8b: clippy +test.deepseek_r1_distill_llama_8b: clippy.core clippy.tests.base $(DEEPSEEK_R1_DISTILL_LLAMA_8B_ENV) cargo test $(CARGO_TEST_LLM_FLAGS) .PHONY: test.glm4_7_flash -test.glm4_7_flash: clippy +test.glm4_7_flash: clippy.core clippy.tests.base $(GLM4_7_FLASH_ENV) cargo test $(CARGO_TEST_LLM_FLAGS) -.PHONY: test.qwen3.5_0.8B -test.qwen3.5_0.8B: clippy - $(QWEN3_5_0_8B_ENV) cargo test $(CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE) - -.PHONY: test.qwen3.6_35b_a3b -test.qwen3.6_35b_a3b: clippy - $(QWEN3_6_35B_A3B_ENV) cargo test $(CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE) - .PHONY: test.llms test.llms: \ test.deepseek_r1_distill_llama_8b \ @@ -65,18 +84,15 @@ test.llms: \ test.qwen3.5_0.8B \ test.qwen3.6_35b_a3b -.PHONY: test -test: test.unit test.llms - -.PHONY: fmt -fmt: - cargo fmt --all --check +.PHONY: test.qwen3.5_0.8B +test.qwen3.5_0.8B: clippy.core clippy.tests.qwen_capable + $(QWEN3_5_0_8B_ENV) cargo test $(CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE) -.PHONY: clippy -clippy: - cargo clippy --all-targets -p llama-cpp-bindings --features $(FEATURES) -- -D warnings - cargo clippy --all-targets -p llama-cpp-bindings-tests $(if $(TEST_FEATURES),--features $(TEST_FEATURES),) -- -D warnings +.PHONY: test.qwen3.6_35b_a3b +test.qwen3.6_35b_a3b: clippy.core clippy.tests.qwen_capable + $(QWEN3_6_35B_A3B_ENV) cargo test $(CARGO_TEST_LLM_FLAGS_QWEN_CAPABLE) -.PHONY: clean.cmake -clean.cmake: - rm -rf target/llama-cpp-cmake-build +.PHONY: test.unit +test.unit: clippy.core + cargo test -p llama-cpp-log-decoder + cargo test -p llama-cpp-bindings $(DEVICE_FEATURE) diff --git a/llama-cpp-bindings-tests/Cargo.toml b/llama-cpp-bindings-tests/Cargo.toml index e7ce3723..c19700da 100644 --- a/llama-cpp-bindings-tests/Cargo.toml +++ b/llama-cpp-bindings-tests/Cargo.toml @@ -10,12 +10,10 @@ publish = false anyhow = { workspace = true } encoding_rs = { workspace = true } hf-hub = { workspace = true } -llama-cpp-bindings = { workspace = true, features = ["sampler"] } +llama-cpp-bindings = { workspace = true } llama-cpp-bindings-sys = { workspace = true } serde_json = { workspace = true } serial_test = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } [features] cuda = ["llama-cpp-bindings/cuda"] diff --git a/llama-cpp-bindings-tests/src/classify_sample_loop.rs b/llama-cpp-bindings-tests/src/classify_sample_loop.rs index 03ad1551..d5b070c4 100644 --- a/llama-cpp-bindings-tests/src/classify_sample_loop.rs +++ b/llama-cpp-bindings-tests/src/classify_sample_loop.rs @@ -1,9 +1,9 @@ use anyhow::Result; use llama_cpp_bindings::context::LlamaContext; +use llama_cpp_bindings::ingest_outcome::IngestOutcome; use llama_cpp_bindings::llama_batch::LlamaBatch; use llama_cpp_bindings::model::LlamaModel; use llama_cpp_bindings::sampled_token::SampledToken; -use llama_cpp_bindings::sampled_token_classifier::IngestOutcome; use llama_cpp_bindings::sampled_token_classifier::SampledTokenClassifier; use llama_cpp_bindings::sampling::LlamaSampler; diff --git a/llama-cpp-bindings-tests/src/gpu_backend.rs b/llama-cpp-bindings-tests/src/gpu_backend.rs index bd9b5f8e..16b6f03c 100644 --- a/llama-cpp-bindings-tests/src/gpu_backend.rs +++ b/llama-cpp-bindings-tests/src/gpu_backend.rs @@ -96,7 +96,7 @@ fn require_backend( #[cfg(test)] mod tests { use llama_cpp_bindings::llama_backend_device::LlamaBackendDevice; - use llama_cpp_bindings::llama_backend_device::LlamaBackendDeviceType; + use llama_cpp_bindings::llama_backend_device_type::LlamaBackendDeviceType; use super::require_backend; diff --git a/llama-cpp-bindings-tests/tests/sampled_token_classifier_markers.rs b/llama-cpp-bindings-tests/tests/sampled_token_classifier_markers.rs index ee747c61..e1c4fef3 100644 --- a/llama-cpp-bindings-tests/tests/sampled_token_classifier_markers.rs +++ b/llama-cpp-bindings-tests/tests/sampled_token_classifier_markers.rs @@ -2,8 +2,8 @@ use anyhow::Result; use llama_cpp_bindings::SampledToken; use llama_cpp_bindings::llama_batch::LlamaBatch; use llama_cpp_bindings::sampled_token_classifier::SampledTokenClassifier; -use llama_cpp_bindings::sampled_token_classifier::SampledTokenSection; -use llama_cpp_bindings::sampled_token_classifier::StreamingMarkers; +use llama_cpp_bindings::sampled_token_section::SampledTokenSection; +use llama_cpp_bindings::streaming_markers::StreamingMarkers; use llama_cpp_bindings_tests::FixtureSession; #[test] diff --git a/llama-cpp-bindings/Cargo.toml b/llama-cpp-bindings/Cargo.toml index 2d407bb1..1500583c 100644 --- a/llama-cpp-bindings/Cargo.toml +++ b/llama-cpp-bindings/Cargo.toml @@ -11,17 +11,16 @@ encoding_rs = { workspace = true } enumflags2 = { workspace = true } llama-cpp-bindings-sys = { workspace = true } llama-cpp-bindings-types = { workspace = true } +llama-cpp-log-decoder = { workspace = true } llguidance = { workspace = true } +log = { workspace = true } nom = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } toktrie = { workspace = true } -tracing = { workspace = true } -tracing-core = { workspace = true } [dev-dependencies] serial_test = { workspace = true } -tracing-subscriber = { workspace = true } [features] default = ["openmp", "android-shared-stdcxx"] @@ -33,7 +32,6 @@ dynamic-backends = ["dynamic-link", "llama-cpp-bindings-sys/dynamic-backends"] vulkan = ["llama-cpp-bindings-sys/vulkan"] openmp = ["llama-cpp-bindings-sys/openmp"] rocm = ["llama-cpp-bindings-sys/rocm"] -sampler = [] # Only has an impact on Android. android-shared-stdcxx = ["llama-cpp-bindings-sys/shared-stdcxx"] android-static-stdcxx = ["llama-cpp-bindings-sys/static-stdcxx"] @@ -55,6 +53,3 @@ module_name_repetitions = "allow" # Generated FFI bindings use these patterns used_underscore_binding = "allow" - -[package.metadata.docs.rs] -features = ["sampler"] diff --git a/llama-cpp-bindings/src/context.rs b/llama-cpp-bindings/src/context.rs index 09d6560d..410ade82 100644 --- a/llama-cpp-bindings/src/context.rs +++ b/llama-cpp-bindings/src/context.rs @@ -39,10 +39,14 @@ const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterR } pub mod kv_cache; +pub mod kv_cache_type; +pub mod llama_attention_type; +pub mod llama_pooling_type; pub mod llama_state_seq_flags; pub mod load_seq_state_error; pub mod load_session_error; pub mod params; +pub mod rope_scaling_type; pub mod save_seq_state_error; pub mod save_session_error; pub mod session; @@ -465,7 +469,7 @@ impl<'model> LlamaContext<'model> { }; check_lora_set_result(err_code)?; - tracing::debug!("Set lora adapter"); + log::debug!("Set lora adapter"); Ok(()) } @@ -491,7 +495,7 @@ impl<'model> LlamaContext<'model> { }; check_lora_remove_result(err_code)?; - tracing::debug!("Remove lora adapter"); + log::debug!("Remove lora adapter"); Ok(()) } } diff --git a/llama-cpp-bindings/src/context/kv_cache_type.rs b/llama-cpp-bindings/src/context/kv_cache_type.rs new file mode 100644 index 00000000..661a59e1 --- /dev/null +++ b/llama-cpp-bindings/src/context/kv_cache_type.rs @@ -0,0 +1,194 @@ +/// A rusty wrapper around `ggml_type` for KV cache types. +#[expect( + non_camel_case_types, + reason = "variant names mirror llama.cpp's `enum ggml_type` symbol names verbatim so they can \ + be matched 1:1 against the C ABI without a translation table" +)] +#[expect( + missing_docs, + reason = "each variant denotes a quantisation flavour whose semantics are defined upstream in \ + ggml; restating the upstream spec inline would risk drifting from the source of truth" +)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum KvCacheType { + /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value. + /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it, + /// the runtime will operate with that type). + /// This variant preserves API compatibility when new `ggml_type` values are + /// introduced in the future. + Unknown(llama_cpp_bindings_sys::ggml_type), + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2_K, + Q3_K, + Q4_K, + Q5_K, + Q6_K, + Q8_K, + IQ2_XXS, + IQ2_XS, + IQ3_XXS, + IQ1_S, + IQ4_NL, + IQ3_S, + IQ2_S, + IQ4_XS, + I8, + I16, + I32, + I64, + F64, + IQ1_M, + BF16, + TQ1_0, + TQ2_0, + MXFP4, +} + +impl From for llama_cpp_bindings_sys::ggml_type { + fn from(value: KvCacheType) -> Self { + match value { + KvCacheType::Unknown(raw) => raw, + KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32, + KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16, + KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0, + KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1, + KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0, + KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1, + KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0, + KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1, + KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K, + KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K, + KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K, + KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K, + KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K, + KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K, + KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS, + KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS, + KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS, + KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S, + KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL, + KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S, + KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S, + KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS, + KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8, + KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16, + KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32, + KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64, + KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64, + KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M, + KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16, + KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0, + KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0, + KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4, + } + } +} + +impl From for KvCacheType { + fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self { + match value { + x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32, + x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS, + x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8, + x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16, + x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32, + x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64, + x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64, + x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M, + x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16, + x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0, + x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0, + x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4, + _ => Self::Unknown(value), + } + } +} + +#[cfg(test)] +mod tests { + use super::KvCacheType; + + #[test] + fn kv_cache_type_unknown_preserves_raw_value() { + let unknown_raw: llama_cpp_bindings_sys::ggml_type = 99999; + let cache_type = KvCacheType::from(unknown_raw); + + assert_eq!(cache_type, KvCacheType::Unknown(99999)); + + let back: llama_cpp_bindings_sys::ggml_type = cache_type.into(); + + assert_eq!(back, 99999); + } + + #[test] + fn kv_cache_type_all_known_variants_roundtrip() { + let all_variants = [ + KvCacheType::F32, + KvCacheType::F16, + KvCacheType::Q4_0, + KvCacheType::Q4_1, + KvCacheType::Q5_0, + KvCacheType::Q5_1, + KvCacheType::Q8_0, + KvCacheType::Q8_1, + KvCacheType::Q2_K, + KvCacheType::Q3_K, + KvCacheType::Q4_K, + KvCacheType::Q5_K, + KvCacheType::Q6_K, + KvCacheType::Q8_K, + KvCacheType::IQ2_XXS, + KvCacheType::IQ2_XS, + KvCacheType::IQ3_XXS, + KvCacheType::IQ1_S, + KvCacheType::IQ4_NL, + KvCacheType::IQ3_S, + KvCacheType::IQ2_S, + KvCacheType::IQ4_XS, + KvCacheType::I8, + KvCacheType::I16, + KvCacheType::I32, + KvCacheType::I64, + KvCacheType::F64, + KvCacheType::IQ1_M, + KvCacheType::BF16, + KvCacheType::TQ1_0, + KvCacheType::TQ2_0, + KvCacheType::MXFP4, + ]; + + for variant in all_variants { + let ggml_type: llama_cpp_bindings_sys::ggml_type = variant.into(); + let back = KvCacheType::from(ggml_type); + + assert_eq!(back, variant); + } + } +} diff --git a/llama-cpp-bindings/src/context/llama_attention_type.rs b/llama-cpp-bindings/src/context/llama_attention_type.rs new file mode 100644 index 00000000..b785ffb0 --- /dev/null +++ b/llama-cpp-bindings/src/context/llama_attention_type.rs @@ -0,0 +1,63 @@ +/// A rusty wrapper around `LLAMA_ATTENTION_TYPE`. +#[repr(i8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum LlamaAttentionType { + /// The attention type is unspecified + Unspecified = -1, + /// Causal attention + Causal = 0, + /// Non-causal attention + NonCausal = 1, +} + +impl From for LlamaAttentionType { + fn from(value: i32) -> Self { + match value { + 0 => Self::Causal, + 1 => Self::NonCausal, + _ => Self::Unspecified, + } + } +} + +impl From for i32 { + fn from(value: LlamaAttentionType) -> Self { + match value { + LlamaAttentionType::Causal => 0, + LlamaAttentionType::NonCausal => 1, + LlamaAttentionType::Unspecified => -1, + } + } +} + +#[cfg(test)] +mod tests { + use super::LlamaAttentionType; + + #[test] + fn attention_type_unknown_defaults_to_unspecified() { + assert_eq!( + LlamaAttentionType::from(99), + LlamaAttentionType::Unspecified + ); + assert_eq!( + LlamaAttentionType::from(-50), + LlamaAttentionType::Unspecified + ); + } + + #[test] + fn attention_type_roundtrip_all_variants() { + for (raw, expected) in [ + (-1, LlamaAttentionType::Unspecified), + (0, LlamaAttentionType::Causal), + (1, LlamaAttentionType::NonCausal), + ] { + let from_raw = LlamaAttentionType::from(raw); + assert_eq!(from_raw, expected); + + let back_to_raw: i32 = from_raw.into(); + assert_eq!(back_to_raw, raw); + } + } +} diff --git a/llama-cpp-bindings/src/context/llama_pooling_type.rs b/llama-cpp-bindings/src/context/llama_pooling_type.rs new file mode 100644 index 00000000..f0d4486b --- /dev/null +++ b/llama-cpp-bindings/src/context/llama_pooling_type.rs @@ -0,0 +1,75 @@ +/// A rusty wrapper around `LLAMA_POOLING_TYPE`. +#[repr(i8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum LlamaPoolingType { + /// The pooling type is unspecified + Unspecified = -1, + /// No pooling + None = 0, + /// Mean pooling + Mean = 1, + /// CLS pooling + Cls = 2, + /// Last pooling + Last = 3, + /// Rank pooling + Rank = 4, +} + +/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if +/// the value is not recognized. +impl From for LlamaPoolingType { + fn from(value: i32) -> Self { + match value { + 0 => Self::None, + 1 => Self::Mean, + 2 => Self::Cls, + 3 => Self::Last, + 4 => Self::Rank, + _ => Self::Unspecified, + } + } +} + +/// Create a `c_int` from a `LlamaPoolingType`. +impl From for i32 { + fn from(value: LlamaPoolingType) -> Self { + match value { + LlamaPoolingType::None => 0, + LlamaPoolingType::Mean => 1, + LlamaPoolingType::Cls => 2, + LlamaPoolingType::Last => 3, + LlamaPoolingType::Rank => 4, + LlamaPoolingType::Unspecified => -1, + } + } +} + +#[cfg(test)] +mod tests { + use super::LlamaPoolingType; + + #[test] + fn pooling_type_unknown_defaults_to_unspecified() { + assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified); + assert_eq!(LlamaPoolingType::from(-50), LlamaPoolingType::Unspecified); + } + + #[test] + fn pooling_type_roundtrip_all_variants() { + for (raw, expected) in [ + (-1, LlamaPoolingType::Unspecified), + (0, LlamaPoolingType::None), + (1, LlamaPoolingType::Mean), + (2, LlamaPoolingType::Cls), + (3, LlamaPoolingType::Last), + (4, LlamaPoolingType::Rank), + ] { + let from_raw = LlamaPoolingType::from(raw); + assert_eq!(from_raw, expected); + + let back_to_raw: i32 = from_raw.into(); + assert_eq!(back_to_raw, raw); + } + } +} diff --git a/llama-cpp-bindings/src/context/params.rs b/llama-cpp-bindings/src/context/params.rs index bcea5898..13935e21 100644 --- a/llama-cpp-bindings/src/context/params.rs +++ b/llama-cpp-bindings/src/context/params.rs @@ -2,256 +2,10 @@ use std::fmt::Debug; use std::num::NonZeroU32; -/// A rusty wrapper around `rope_scaling_type`. -#[repr(i8)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum RopeScalingType { - /// The scaling type is unspecified - Unspecified = -1, - /// No scaling - None = 0, - /// Linear scaling - Linear = 1, - /// Yarn scaling - Yarn = 2, -} - -/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if -/// the value is not recognized. -impl From for RopeScalingType { - fn from(value: i32) -> Self { - match value { - 0 => Self::None, - 1 => Self::Linear, - 2 => Self::Yarn, - _ => Self::Unspecified, - } - } -} - -/// Create a `c_int` from a `RopeScalingType`. -impl From for i32 { - fn from(value: RopeScalingType) -> Self { - match value { - RopeScalingType::None => 0, - RopeScalingType::Linear => 1, - RopeScalingType::Yarn => 2, - RopeScalingType::Unspecified => -1, - } - } -} - -/// A rusty wrapper around `LLAMA_POOLING_TYPE`. -#[repr(i8)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum LlamaPoolingType { - /// The pooling type is unspecified - Unspecified = -1, - /// No pooling - None = 0, - /// Mean pooling - Mean = 1, - /// CLS pooling - Cls = 2, - /// Last pooling - Last = 3, - /// Rank pooling - Rank = 4, -} - -/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if -/// the value is not recognized. -impl From for LlamaPoolingType { - fn from(value: i32) -> Self { - match value { - 0 => Self::None, - 1 => Self::Mean, - 2 => Self::Cls, - 3 => Self::Last, - 4 => Self::Rank, - _ => Self::Unspecified, - } - } -} - -/// Create a `c_int` from a `LlamaPoolingType`. -impl From for i32 { - fn from(value: LlamaPoolingType) -> Self { - match value { - LlamaPoolingType::None => 0, - LlamaPoolingType::Mean => 1, - LlamaPoolingType::Cls => 2, - LlamaPoolingType::Last => 3, - LlamaPoolingType::Rank => 4, - LlamaPoolingType::Unspecified => -1, - } - } -} - -/// A rusty wrapper around `LLAMA_ATTENTION_TYPE`. -#[repr(i8)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum LlamaAttentionType { - /// The attention type is unspecified - Unspecified = -1, - /// Causal attention - Causal = 0, - /// Non-causal attention - NonCausal = 1, -} - -impl From for LlamaAttentionType { - fn from(value: i32) -> Self { - match value { - 0 => Self::Causal, - 1 => Self::NonCausal, - _ => Self::Unspecified, - } - } -} - -impl From for i32 { - fn from(value: LlamaAttentionType) -> Self { - match value { - LlamaAttentionType::Causal => 0, - LlamaAttentionType::NonCausal => 1, - LlamaAttentionType::Unspecified => -1, - } - } -} - -/// A rusty wrapper around `ggml_type` for KV cache types. -#[expect( - non_camel_case_types, - reason = "variant names mirror llama.cpp's `enum ggml_type` symbol names verbatim so they can \ - be matched 1:1 against the C ABI without a translation table" -)] -#[expect( - missing_docs, - reason = "each variant denotes a quantisation flavour whose semantics are defined upstream in \ - ggml; restating the upstream spec inline would risk drifting from the source of truth" -)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum KvCacheType { - /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value. - /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it, - /// the runtime will operate with that type). - /// This variant preserves API compatibility when new `ggml_type` values are - /// introduced in the future. - Unknown(llama_cpp_bindings_sys::ggml_type), - F32, - F16, - Q4_0, - Q4_1, - Q5_0, - Q5_1, - Q8_0, - Q8_1, - Q2_K, - Q3_K, - Q4_K, - Q5_K, - Q6_K, - Q8_K, - IQ2_XXS, - IQ2_XS, - IQ3_XXS, - IQ1_S, - IQ4_NL, - IQ3_S, - IQ2_S, - IQ4_XS, - I8, - I16, - I32, - I64, - F64, - IQ1_M, - BF16, - TQ1_0, - TQ2_0, - MXFP4, -} - -impl From for llama_cpp_bindings_sys::ggml_type { - fn from(value: KvCacheType) -> Self { - match value { - KvCacheType::Unknown(raw) => raw, - KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32, - KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16, - KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0, - KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1, - KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0, - KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1, - KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0, - KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1, - KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K, - KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K, - KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K, - KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K, - KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K, - KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K, - KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS, - KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS, - KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS, - KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S, - KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL, - KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S, - KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S, - KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS, - KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8, - KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16, - KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32, - KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64, - KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64, - KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M, - KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16, - KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0, - KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0, - KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4, - } - } -} - -impl From for KvCacheType { - fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self { - match value { - x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32, - x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS, - x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8, - x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16, - x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32, - x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64, - x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64, - x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M, - x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16, - x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0, - x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0, - x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4, - _ => Self::Unknown(value), - } - } -} +pub use crate::context::kv_cache_type::KvCacheType; +pub use crate::context::llama_attention_type::LlamaAttentionType; +pub use crate::context::llama_pooling_type::LlamaPoolingType; +pub use crate::context::rope_scaling_type::RopeScalingType; /// A safe wrapper around `llama_context_params`. /// @@ -1033,30 +787,6 @@ impl Default for LlamaContextParams { mod tests { use super::{KvCacheType, LlamaAttentionType, LlamaPoolingType, RopeScalingType}; - #[test] - fn rope_scaling_type_unknown_defaults_to_unspecified() { - assert_eq!(RopeScalingType::from(99), RopeScalingType::Unspecified); - assert_eq!(RopeScalingType::from(-100), RopeScalingType::Unspecified); - } - - #[test] - fn pooling_type_unknown_defaults_to_unspecified() { - assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified); - assert_eq!(LlamaPoolingType::from(-50), LlamaPoolingType::Unspecified); - } - - #[test] - fn kv_cache_type_unknown_preserves_raw_value() { - let unknown_raw: llama_cpp_bindings_sys::ggml_type = 99999; - let cache_type = KvCacheType::from(unknown_raw); - - assert_eq!(cache_type, KvCacheType::Unknown(99999)); - - let back: llama_cpp_bindings_sys::ggml_type = cache_type.into(); - - assert_eq!(back, 99999); - } - #[test] fn default_params_have_expected_values() { let params = super::LlamaContextParams::default(); @@ -1284,85 +1014,6 @@ mod tests { assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON); } - #[test] - fn rope_scaling_type_roundtrip_all_variants() { - for (raw, expected) in [ - (-1, RopeScalingType::Unspecified), - (0, RopeScalingType::None), - (1, RopeScalingType::Linear), - (2, RopeScalingType::Yarn), - ] { - let from_raw = RopeScalingType::from(raw); - assert_eq!(from_raw, expected); - - let back_to_raw: i32 = from_raw.into(); - assert_eq!(back_to_raw, raw); - } - } - - #[test] - fn pooling_type_roundtrip_all_variants() { - for (raw, expected) in [ - (-1, LlamaPoolingType::Unspecified), - (0, LlamaPoolingType::None), - (1, LlamaPoolingType::Mean), - (2, LlamaPoolingType::Cls), - (3, LlamaPoolingType::Last), - (4, LlamaPoolingType::Rank), - ] { - let from_raw = LlamaPoolingType::from(raw); - assert_eq!(from_raw, expected); - - let back_to_raw: i32 = from_raw.into(); - assert_eq!(back_to_raw, raw); - } - } - - #[test] - fn kv_cache_type_all_known_variants_roundtrip() { - let all_variants = [ - KvCacheType::F32, - KvCacheType::F16, - KvCacheType::Q4_0, - KvCacheType::Q4_1, - KvCacheType::Q5_0, - KvCacheType::Q5_1, - KvCacheType::Q8_0, - KvCacheType::Q8_1, - KvCacheType::Q2_K, - KvCacheType::Q3_K, - KvCacheType::Q4_K, - KvCacheType::Q5_K, - KvCacheType::Q6_K, - KvCacheType::Q8_K, - KvCacheType::IQ2_XXS, - KvCacheType::IQ2_XS, - KvCacheType::IQ3_XXS, - KvCacheType::IQ1_S, - KvCacheType::IQ4_NL, - KvCacheType::IQ3_S, - KvCacheType::IQ2_S, - KvCacheType::IQ4_XS, - KvCacheType::I8, - KvCacheType::I16, - KvCacheType::I32, - KvCacheType::I64, - KvCacheType::F64, - KvCacheType::IQ1_M, - KvCacheType::BF16, - KvCacheType::TQ1_0, - KvCacheType::TQ2_0, - KvCacheType::MXFP4, - ]; - - for variant in all_variants { - let ggml_type: llama_cpp_bindings_sys::ggml_type = variant.into(); - let back = KvCacheType::from(ggml_type); - - assert_eq!(back, variant); - } - } - #[test] fn with_cb_eval_sets_callback() { extern "C" fn test_cb_eval( @@ -1402,33 +1053,6 @@ mod tests { ); } - #[test] - fn attention_type_unknown_defaults_to_unspecified() { - assert_eq!( - LlamaAttentionType::from(99), - LlamaAttentionType::Unspecified - ); - assert_eq!( - LlamaAttentionType::from(-50), - LlamaAttentionType::Unspecified - ); - } - - #[test] - fn attention_type_roundtrip_all_variants() { - for (raw, expected) in [ - (-1, LlamaAttentionType::Unspecified), - (0, LlamaAttentionType::Causal), - (1, LlamaAttentionType::NonCausal), - ] { - let from_raw = LlamaAttentionType::from(raw); - assert_eq!(from_raw, expected); - - let back_to_raw: i32 = from_raw.into(); - assert_eq!(back_to_raw, raw); - } - } - #[test] fn with_attention_type_causal() { let params = diff --git a/llama-cpp-bindings/src/context/rope_scaling_type.rs b/llama-cpp-bindings/src/context/rope_scaling_type.rs new file mode 100644 index 00000000..0bbfa831 --- /dev/null +++ b/llama-cpp-bindings/src/context/rope_scaling_type.rs @@ -0,0 +1,65 @@ +/// A rusty wrapper around `rope_scaling_type`. +#[repr(i8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum RopeScalingType { + /// The scaling type is unspecified + Unspecified = -1, + /// No scaling + None = 0, + /// Linear scaling + Linear = 1, + /// Yarn scaling + Yarn = 2, +} + +/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if +/// the value is not recognized. +impl From for RopeScalingType { + fn from(value: i32) -> Self { + match value { + 0 => Self::None, + 1 => Self::Linear, + 2 => Self::Yarn, + _ => Self::Unspecified, + } + } +} + +/// Create a `c_int` from a `RopeScalingType`. +impl From for i32 { + fn from(value: RopeScalingType) -> Self { + match value { + RopeScalingType::None => 0, + RopeScalingType::Linear => 1, + RopeScalingType::Yarn => 2, + RopeScalingType::Unspecified => -1, + } + } +} + +#[cfg(test)] +mod tests { + use super::RopeScalingType; + + #[test] + fn rope_scaling_type_unknown_defaults_to_unspecified() { + assert_eq!(RopeScalingType::from(99), RopeScalingType::Unspecified); + assert_eq!(RopeScalingType::from(-100), RopeScalingType::Unspecified); + } + + #[test] + fn rope_scaling_type_roundtrip_all_variants() { + for (raw, expected) in [ + (-1, RopeScalingType::Unspecified), + (0, RopeScalingType::None), + (1, RopeScalingType::Linear), + (2, RopeScalingType::Yarn), + ] { + let from_raw = RopeScalingType::from(raw); + assert_eq!(from_raw, expected); + + let back_to_raw: i32 = from_raw.into(); + assert_eq!(back_to_raw, raw); + } + } +} diff --git a/llama-cpp-bindings/src/error.rs b/llama-cpp-bindings/src/error.rs index d48e2596..2314452f 100644 --- a/llama-cpp-bindings/src/error.rs +++ b/llama-cpp-bindings/src/error.rs @@ -1,614 +1,68 @@ -use std::ffi::NulError; -use std::num::NonZeroI32; -use std::os::raw::c_int; -use std::path::PathBuf; -use std::string::FromUtf8Error; - -use crate::batch_add_error::BatchAddError; -use crate::mtmd::MtmdEvalError; -use crate::mtmd::mtmd_input_chunk_type::MtmdInputChunkTypeError; +pub mod apply_chat_template_error; +pub mod bracketed_args_failure; +pub mod chat_template_error; +pub mod decode_error; +pub mod embeddings_error; +pub mod encode_error; +pub mod eval_multimodal_chunks_error; +pub mod fit_error; +pub mod grammar_error; +pub mod json_object_failure; +pub mod key_value_xml_tags_failure; +pub mod llama_context_load_error; +pub mod llama_cpp_error; +pub mod llama_lora_adapter_init_error; +pub mod llama_lora_adapter_remove_error; +pub mod llama_lora_adapter_set_error; +pub mod llama_model_load_error; +pub mod logits_error; +pub mod marker_detection_error; +pub mod meta_val_error; +pub mod model_params_error; +pub mod new_llama_chat_message_error; +pub mod paired_quote_failure; +pub mod parse_chat_message_error; +pub mod sample_error; +pub mod sampler_accept_error; +pub mod sampling_error; +pub mod string_to_token_error; +pub mod token_sampling_error; +pub mod token_to_string_error; +pub mod tool_call_format_failure; +pub mod xml_function_tags_failure; + +pub use apply_chat_template_error::ApplyChatTemplateError; +pub use bracketed_args_failure::BracketedArgsFailure; +pub use chat_template_error::ChatTemplateError; +pub use decode_error::DecodeError; +pub use embeddings_error::EmbeddingsError; +pub use encode_error::EncodeError; +pub use eval_multimodal_chunks_error::EvalMultimodalChunksError; +pub use fit_error::FitError; +pub use grammar_error::GrammarError; +pub use json_object_failure::JsonObjectFailure; +pub use key_value_xml_tags_failure::KeyValueXmlTagsFailure; +pub use llama_context_load_error::LlamaContextLoadError; +pub use llama_cpp_error::LlamaCppError; +pub use llama_lora_adapter_init_error::LlamaLoraAdapterInitError; +pub use llama_lora_adapter_remove_error::LlamaLoraAdapterRemoveError; +pub use llama_lora_adapter_set_error::LlamaLoraAdapterSetError; +pub use llama_model_load_error::LlamaModelLoadError; +pub use logits_error::LogitsError; +pub use marker_detection_error::MarkerDetectionError; +pub use meta_val_error::MetaValError; +pub use model_params_error::ModelParamsError; +pub use new_llama_chat_message_error::NewLlamaChatMessageError; +pub use paired_quote_failure::PairedQuoteFailure; +pub use parse_chat_message_error::ParseChatMessageError; +pub use sample_error::SampleError; +pub use sampler_accept_error::SamplerAcceptError; +pub use sampling_error::SamplingError; +pub use string_to_token_error::StringToTokenError; +pub use token_sampling_error::TokenSamplingError; +pub use token_to_string_error::TokenToStringError; +pub use tool_call_format_failure::ToolCallFormatFailure; +pub use xml_function_tags_failure::XmlFunctionTagsFailure; /// A failable result from a llama.cpp function. pub type Result = std::result::Result; - -/// All errors that can occur in the llama-cpp crate. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaCppError { - /// The backend was already initialized. This can generally be ignored as initializing the backend - /// is idempotent. - #[error("BackendAlreadyInitialized")] - BackendAlreadyInitialized, - /// There was an error while get the chat template from model. - #[error("{0}")] - ChatTemplateError(#[from] ChatTemplateError), - /// There was an error while decoding a batch. - #[error("{0}")] - DecodeError(#[from] DecodeError), - /// There was an error while encoding a batch. - #[error("{0}")] - EncodeError(#[from] EncodeError), - /// There was an error loading a model. - #[error("{0}")] - LlamaModelLoadError(#[from] LlamaModelLoadError), - /// There was an error creating a new model context. - #[error("{0}")] - LlamaContextLoadError(#[from] LlamaContextLoadError), - /// There was an error adding a token to a batch. - #[error["{0}"]] - BatchAddError(#[from] BatchAddError), - /// see [`EmbeddingsError`] - #[error(transparent)] - EmbeddingError(#[from] EmbeddingsError), - /// Backend device not found - #[error("Backend device {0} not found")] - BackendDeviceNotFound(usize), - /// Max devices exceeded - #[error("Max devices exceeded. Max devices is {0}")] - MaxDevicesExceeded(usize), - /// Failed to convert JSON schema to grammar. - #[error("JsonSchemaToGrammarError: {0}")] - JsonSchemaToGrammarError(String), - /// see [`FitError`] - #[error(transparent)] - FitError(#[from] FitError), -} - -/// There was an error while getting the chat template from a model. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum ChatTemplateError { - /// gguf has no chat template (by that name) - #[error("chat template not found - returned null pointer")] - MissingTemplate, - - /// chat template contained a null byte - #[error("null byte in string {0}")] - NullError(#[from] NulError), - - /// The chat template was not valid utf8. - #[error(transparent)] - Utf8Error(#[from] std::str::Utf8Error), -} - -/// Failed fetching metadata value -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum MetaValError { - /// The provided string contains an unexpected null-byte - #[error("null byte in string {0}")] - NullError(#[from] NulError), - - /// The returned data contains invalid UTF8 data - #[error("FromUtf8Error {0}")] - FromUtf8Error(#[from] FromUtf8Error), - - /// Got negative return value. This happens if the key or index queried does not exist. - #[error("Negative return value. Likely due to a missing index or key. Got return value: {0}")] - NegativeReturn(i32), -} - -/// Failed to Load context -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaContextLoadError { - /// llama.cpp returned null - #[error("null reference from llama.cpp")] - NullReturn, -} - -/// Failed to decode a batch. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum DecodeError { - /// No kv cache slot was available. - #[error("Decode Error 1: NoKvCacheSlot")] - NoKvCacheSlot, - /// The computation was aborted by the abort callback. - #[error("Decode Error 2: Aborted")] - Aborted, - /// The number of tokens in the batch was 0. - #[error("Decode Error -1: n_tokens == 0")] - NTokensZero, - /// An unknown error occurred. - #[error("Decode Error {0}: unknown")] - Unknown(c_int), -} - -/// Failed to decode a batch. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum EncodeError { - /// No kv cache slot was available. - #[error("Encode Error 1: NoKvCacheSlot")] - NoKvCacheSlot, - /// The number of tokens in the batch was 0. - #[error("Encode Error -1: n_tokens == 0")] - NTokensZero, - /// An unknown error occurred. - #[error("Encode Error {0}: unknown")] - Unknown(c_int), -} - -/// When embedding related functions fail -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum EmbeddingsError { - /// Embeddings weren't enabled in the context options - #[error("Embeddings weren't enabled in the context options")] - NotEnabled, - /// Logits weren't enabled for the given token - #[error("Logits were not enabled for the given token")] - LogitsNotEnabled, - /// The given sequence index exceeds the max sequence id - #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")] - NonePoolType, - /// The embedding dimension does not fit into a usize. - #[error("Invalid embedding dimension: {0}")] - InvalidEmbeddingDimension(#[source] std::num::TryFromIntError), -} - -/// When logits-related functions fail -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LogitsError { - /// The logits data pointer is null. - #[error("logits data pointer is null")] - NullLogits, - /// The requested token index has not been initialized for logits. - #[error("logit for token index {0} is not initialized")] - TokenNotInitialized(i32), - /// The token index exceeds the context size. - #[error("token index {token_index} exceeds context size {context_size}")] - TokenIndexExceedsContext { - /// The token index that was requested. - token_index: u32, - /// The context size. - context_size: u32, - }, - /// The vocabulary size does not fit into a usize. - #[error("n_vocab does not fit into usize: {0}")] - VocabSizeOverflow(#[source] std::num::TryFromIntError), - /// The token index does not fit into a u32. - #[error("token_index does not fit into u32: {0}")] - TokenIndexOverflow(#[source] std::num::TryFromIntError), -} - -/// Errors that can occur when initializing a grammar sampler -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum GrammarError { - /// The grammar root was not found in the grammar string - #[error("Grammar root not found in grammar string")] - RootNotFound, - /// The trigger word contains null bytes - #[error("Trigger word contains null bytes: {0}")] - TriggerWordNullBytes(NulError), - /// The grammar string or root contains null bytes - #[error("Grammar string or root contains null bytes: {0}")] - GrammarNullBytes(NulError), - /// A string contains null bytes - #[error("String contains null bytes: {0}")] - NulError(#[from] NulError), - /// The grammar call returned null - #[error("Grammar initialization failed: {0}")] - NullGrammar(String), - /// An integer value exceeded the allowed range - #[error("Integer overflow: {0}")] - IntegerOverflow(String), - /// An error from the llguidance library - #[error("llguidance error: {0}")] - LlguidanceError(String), -} - -/// Errors that can occur when creating a sampling configuration. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum SamplingError { - /// An integer value exceeded the allowed range - #[error("Integer overflow: {0}")] - IntegerOverflow(String), -} - -/// Errors that can occur when sampling a token. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum SampleError { - /// A C++ exception was thrown during sampling - #[error("C++ exception during sampling: {0}")] - CppException(String), - - /// An invalid argument was passed to the sampler - #[error("Invalid argument passed to sampler")] - InvalidArgument, -} - -/// Decode a error from llama.cpp into a [`DecodeError`]. -impl From for DecodeError { - fn from(value: NonZeroI32) -> Self { - match value.get() { - 1 => Self::NoKvCacheSlot, - 2 => Self::Aborted, - -1 => Self::NTokensZero, - error_code => Self::Unknown(error_code), - } - } -} - -/// Encode a error from llama.cpp into a [`EncodeError`]. -impl From for EncodeError { - fn from(value: NonZeroI32) -> Self { - match value.get() { - 1 => Self::NoKvCacheSlot, - -1 => Self::NTokensZero, - error_code => Self::Unknown(error_code), - } - } -} - -/// An error that can occur when loading a model. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaModelLoadError { - /// There was a null byte in a provided string and thus it could not be converted to a C string. - #[error("null byte in string {0}")] - NullError(#[from] NulError), - /// llama.cpp returned a nullptr - this could be many different causes. - #[error("null result from llama cpp")] - NullResult, - /// Failed to convert the path to a rust str. This means the path was not valid unicode - #[error("failed to convert path {0} to str")] - PathToStrError(PathBuf), - /// The model file does not exist at the given path. - #[error("model file not found: {0}")] - FileNotFound(PathBuf), -} - -/// An error that can occur when loading a model. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaLoraAdapterInitError { - /// There was a null byte in a provided string and thus it could not be converted to a C string. - #[error("null byte in string {0}")] - NullError(#[from] NulError), - /// llama.cpp returned a nullptr - this could be many different causes. - #[error("null result from llama cpp")] - NullResult, - /// Failed to convert the path to a rust str. This means the path was not valid unicode - #[error("failed to convert path {0} to str")] - PathToStrError(PathBuf), - /// The adapter file does not exist at the given path. - #[error("adapter file not found: {0}")] - FileNotFound(PathBuf), -} - -/// An error that can occur when loading a model. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaLoraAdapterSetError { - /// llama.cpp returned a non-zero error code. - #[error("error code from llama cpp")] - ErrorResult(i32), -} - -/// An error that can occur when loading a model. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaLoraAdapterRemoveError { - /// llama.cpp returned a non-zero error code. - #[error("error code from llama cpp")] - ErrorResult(i32), -} - -/// An error that can occur when converting a token to a string. -#[derive(Debug, thiserror::Error, Clone)] -#[non_exhaustive] -pub enum TokenToStringError { - /// the token type was unknown - #[error("Unknown Token Type")] - UnknownTokenType, - /// There was insufficient buffer space to convert the token to a string. - #[error("Insufficient Buffer Space {0}")] - InsufficientBufferSpace(c_int), - /// The token was not valid utf8. - #[error("FromUtf8Error {0}")] - FromUtf8Error(#[from] FromUtf8Error), - /// An integer conversion failed. - #[error("Integer conversion error: {0}")] - IntConversionError(#[from] std::num::TryFromIntError), -} - -/// Failed to convert a string to a token sequence. -#[derive(Debug, thiserror::Error)] -pub enum StringToTokenError { - /// the string contained a null byte and thus could not be converted to a c string. - #[error("{0}")] - NulError(#[from] NulError), - #[error("{0}")] - /// Failed to convert a provided integer to a [`c_int`]. - CIntConversionError(#[from] std::num::TryFromIntError), -} - -/// Failed to apply model chat template. -#[derive(Debug, thiserror::Error)] -pub enum NewLlamaChatMessageError { - /// the string contained a null byte and thus could not be converted to a c string. - #[error("{0}")] - NulError(#[from] NulError), -} - -/// Failed to apply model chat template. -#[derive(Debug, thiserror::Error)] -pub enum ApplyChatTemplateError { - /// the string could not be converted to utf8. - #[error("{0}")] - FromUtf8Error(#[from] FromUtf8Error), - /// An integer conversion failed. - #[error("Integer conversion error: {0}")] - IntConversionError(#[from] std::num::TryFromIntError), -} - -/// Failed to detect tool-call diagnostic markers for a model. -#[derive(Debug, thiserror::Error)] -pub enum MarkerDetectionError { - /// llama.cpp returned an error code from the marker detection FFI call. - #[error("ffi error {0}")] - FfiError(i32), - /// The C++ side threw an exception during template analysis. - #[error("c++ exception during template analysis: {0}")] - AnalyzeException(String), - /// llama.cpp returned a marker string but its bytes were not valid UTF-8. - #[error("ffi returned non-utf8 marker bytes: {0}")] - MarkerUtf8Error(#[from] FromUtf8Error), -} - -/// Failed to parse a chat message via [`crate::Model::parse_chat_message`]. -#[derive(Debug, thiserror::Error)] -pub enum ParseChatMessageError { - /// llama.cpp returned an error code from the parse FFI call. - #[error("ffi error {0}")] - FfiError(i32), - /// The C++ side threw an exception while parsing. - #[error("c++ exception during chat parse: {0}")] - ParseException(String), - /// An accessor returned bytes that were not valid UTF-8. - #[error("ffi returned non-utf8 string: {0}")] - StringUtf8Error(#[from] FromUtf8Error), - /// The caller passed a `tools_json` argument that is not valid JSON. - #[error("tools_json is not valid JSON: {0}")] - ToolsJsonInvalid(#[source] serde_json::Error), - /// The caller passed a `tools_json` argument that parses as JSON but is not an array. - #[error("tools_json must be a JSON array")] - ToolsJsonNotArray, - /// Failed to serialize the tools array for the FFI call. - #[error("could not serialize tools to JSON: {0}")] - ToolsSerialization(String), - /// The model has no usable chat template, so the parser cannot be built. - #[error("model has no chat template")] - NoChatTemplate, - /// The wrapper-side fallback parser detected a structural issue while parsing the body. - #[error("template-override fallback parser failed: {0}")] - TemplateOverrideFailed(#[from] ToolCallFormatFailure), -} - -/// Top-level failure for the wrapper-side template-override parsers (one variant per supported shape). -#[derive(Debug, thiserror::Error)] -pub enum ToolCallFormatFailure { - #[error("bracketed-args fallback parser: {0}")] - BracketedArgs(#[from] BracketedArgsFailure), - #[error("json-object fallback parser: {0}")] - JsonObject(#[from] JsonObjectFailure), - #[error("key-value-xml-tags fallback parser: {0}")] - KeyValueXmlTags(#[from] KeyValueXmlTagsFailure), - #[error("paired-quote fallback parser: {0}")] - PairedQuote(#[from] PairedQuoteFailure), - #[error("xml-function-tags fallback parser: {0}")] - XmlFunctionTags(#[from] XmlFunctionTagsFailure), -} - -/// Failures specific to the JSON-object args parser (Qwen 3 `{"name":..., "arguments":...}`). -#[derive(Debug, thiserror::Error)] -pub enum JsonObjectFailure { - #[error("tool call body has malformed JSON: {message}")] - InvalidJson { message: String }, -} - -/// Failures specific to the bracketed-JSON args parser (Mistral 3 `[TOOL_CALLS]name[ARGS]{...}`). -#[derive(Debug, thiserror::Error)] -pub enum BracketedArgsFailure { - #[error("tool call '{tool_name}' arguments are not valid JSON: {message}")] - InvalidJsonArguments { tool_name: String, message: String }, - #[error("tool call '{tool_name}' arguments truncated before JSON value completed")] - UnterminatedArguments { tool_name: String }, -} - -/// Failures specific to the paired-quote args parser (Gemma 4 `<|tool_call>call:name{key:<|"|>val<|"|>}`). -#[derive(Debug, thiserror::Error)] -pub enum PairedQuoteFailure { - #[error("empty key in tool call '{tool_name}' arguments")] - EmptyKey { tool_name: String }, - #[error("tool call '{tool_name}' translated arguments are not valid JSON: {message}")] - InvalidJsonArguments { tool_name: String, message: String }, - #[error("tool call '{tool_name}' has unclosed quoted value for key '{key}'")] - UnclosedQuotedValue { tool_name: String, key: String }, - #[error("tool call '{tool_name}' arguments ended without close marker (state: {state})")] - UnclosedArgumentBlock { - tool_name: String, - state: &'static str, - }, - #[error( - "tool call '{tool_name}' has unexpected character '{character}' after value for key '{key}'" - )] - UnexpectedCharAfterValue { - tool_name: String, - key: String, - character: char, - }, -} - -/// Failures specific to the key-value XML-tags parser (GLM-4.7 `{name}{k}{v}...`). -#[derive(Debug, thiserror::Error)] -pub enum KeyValueXmlTagsFailure { - #[error("tool call function tag has empty name")] - EmptyFunctionName, - #[error("tool call function block is missing close tag '{expected_close}'")] - UnclosedFunctionBlock { expected_close: String }, - #[error("tool call function '{function_name}' has key tag with empty content")] - EmptyKey { function_name: String }, - #[error("tool call function '{function_name}' is missing key close tag '{expected_close}'")] - UnclosedKeyTag { - function_name: String, - expected_close: String, - }, - #[error( - "tool call function '{function_name}' key '{key}' is missing value open tag '{expected_open}'" - )] - MissingValueTag { - function_name: String, - key: String, - expected_open: String, - }, - #[error( - "tool call function '{function_name}' key '{key}' is missing value close tag '{expected_close}'" - )] - UnclosedValueTag { - function_name: String, - key: String, - expected_close: String, - }, -} - -/// Failures specific to the XML function-tags parser (Qwen 3.5+ `val`). -#[derive(Debug, thiserror::Error)] -pub enum XmlFunctionTagsFailure { - #[error("tool call function tag has empty name")] - EmptyFunctionName, - #[error("tool call function '{function_name}' is missing close tag '{expected_close}'")] - UnclosedFunctionBlock { - function_name: String, - expected_close: String, - }, - #[error("tool call function '{function_name}' has parameter with empty name")] - EmptyParameterName { function_name: String }, - #[error( - "tool call function '{function_name}' parameter '{parameter_name}' is missing close tag '{expected_close}'" - )] - UnclosedParameterBlock { - function_name: String, - parameter_name: String, - expected_close: String, - }, -} - -/// Failed to evaluate multimodal chunks through the request classifier. -#[derive(Debug, thiserror::Error)] -pub enum EvalMultimodalChunksError { - /// `MtmdInputChunks::eval_chunks` returned an error. - #[error("{0}")] - EvalFailed(#[from] MtmdEvalError), - /// A chunk reported a type that is not known to this binding. - #[error("{0}")] - UnknownChunkType(#[from] MtmdInputChunkTypeError), - /// A chunk index that was within `chunks.len()` returned `None` from `chunks.get(index)`. - #[error("chunk index {0} out of bounds during post-eval walk")] - ChunkOutOfBounds(usize), -} - -/// Failed to accept a token in a sampler. -#[derive(Debug, thiserror::Error)] -pub enum SamplerAcceptError { - /// A C++ exception was thrown during accept - #[error("C++ exception during sampler accept: {0}")] - CppException(String), - - /// An invalid argument was passed (null sampler or null error pointer) - #[error("Invalid argument passed to sampler accept")] - InvalidArgument, -} - -/// Errors that can occur when modifying model parameters. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum ModelParamsError { - /// The internal override vector has no available slot. - #[error("No available slot in override vector")] - NoAvailableSlot, - /// The first override slot is not empty. - #[error("Override slot is not empty")] - SlotNotEmpty, - /// A character in the key is not a valid C char. - #[error("Invalid character in key: byte {byte}, {reason}")] - InvalidCharacterInKey { - /// The byte value that failed conversion. - byte: u8, - /// The reason the conversion failed. - reason: String, - }, -} - -/// Failed to sample a token from the data array. -#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum TokenSamplingError { - /// The sampler did not select any token. - #[error("No token was selected by the sampler")] - NoTokenSelected, -} - -/// Returned by [`crate::model::params::LlamaModelParams::fit_params`]. -#[derive(Debug, Clone, Copy, Eq, PartialEq, thiserror::Error)] -pub enum FitError { - /// Could not find allocations that fit available memory. - #[error("could not find allocations that fit available memory")] - Failure, - /// A hard error occurred during fitting (e.g. model not found at the specified path, - /// or the C++ wrapper threw an exception). - #[error("hard error during parameter fitting")] - Error, -} - -#[cfg(test)] -mod tests { - use std::num::NonZeroI32; - - use super::{DecodeError, EncodeError}; - - #[test] - fn decode_error_no_kv_cache_slot() { - let error = DecodeError::from(NonZeroI32::new(1).expect("1 is non-zero")); - - assert_eq!(error, DecodeError::NoKvCacheSlot); - assert_eq!(error.to_string(), "Decode Error 1: NoKvCacheSlot"); - } - - #[test] - fn decode_error_n_tokens_zero() { - let error = DecodeError::from(NonZeroI32::new(-1).expect("-1 is non-zero")); - - assert_eq!(error, DecodeError::NTokensZero); - assert_eq!(error.to_string(), "Decode Error -1: n_tokens == 0"); - } - - #[test] - fn decode_error_aborted() { - let error = DecodeError::from(NonZeroI32::new(2).expect("2 is non-zero")); - - assert_eq!(error, DecodeError::Aborted); - assert_eq!(error.to_string(), "Decode Error 2: Aborted"); - } - - #[test] - fn decode_error_unknown() { - let error = DecodeError::from(NonZeroI32::new(42).expect("42 is non-zero")); - - assert_eq!(error, DecodeError::Unknown(42)); - assert_eq!(error.to_string(), "Decode Error 42: unknown"); - } - - #[test] - fn encode_error_no_kv_cache_slot() { - let error = EncodeError::from(NonZeroI32::new(1).expect("1 is non-zero")); - - assert_eq!(error, EncodeError::NoKvCacheSlot); - assert_eq!(error.to_string(), "Encode Error 1: NoKvCacheSlot"); - } - - #[test] - fn encode_error_n_tokens_zero() { - let error = EncodeError::from(NonZeroI32::new(-1).expect("-1 is non-zero")); - - assert_eq!(error, EncodeError::NTokensZero); - assert_eq!(error.to_string(), "Encode Error -1: n_tokens == 0"); - } - - #[test] - fn encode_error_unknown() { - let error = EncodeError::from(NonZeroI32::new(99).expect("99 is non-zero")); - - assert_eq!(error, EncodeError::Unknown(99)); - assert_eq!(error.to_string(), "Encode Error 99: unknown"); - } -} diff --git a/llama-cpp-bindings/src/error/apply_chat_template_error.rs b/llama-cpp-bindings/src/error/apply_chat_template_error.rs new file mode 100644 index 00000000..251dda35 --- /dev/null +++ b/llama-cpp-bindings/src/error/apply_chat_template_error.rs @@ -0,0 +1,12 @@ +use std::string::FromUtf8Error; + +/// Failed to apply model chat template. +#[derive(Debug, thiserror::Error)] +pub enum ApplyChatTemplateError { + /// the string could not be converted to utf8. + #[error("{0}")] + FromUtf8Error(#[from] FromUtf8Error), + /// An integer conversion failed. + #[error("Integer conversion error: {0}")] + IntConversionError(#[from] std::num::TryFromIntError), +} diff --git a/llama-cpp-bindings/src/error/bracketed_args_failure.rs b/llama-cpp-bindings/src/error/bracketed_args_failure.rs new file mode 100644 index 00000000..8750a9be --- /dev/null +++ b/llama-cpp-bindings/src/error/bracketed_args_failure.rs @@ -0,0 +1,8 @@ +/// Failures specific to the bracketed-JSON args parser (Mistral 3 `[TOOL_CALLS]name[ARGS]{...}`). +#[derive(Debug, thiserror::Error)] +pub enum BracketedArgsFailure { + #[error("tool call '{tool_name}' arguments are not valid JSON: {message}")] + InvalidJsonArguments { tool_name: String, message: String }, + #[error("tool call '{tool_name}' arguments truncated before JSON value completed")] + UnterminatedArguments { tool_name: String }, +} diff --git a/llama-cpp-bindings/src/error/chat_template_error.rs b/llama-cpp-bindings/src/error/chat_template_error.rs new file mode 100644 index 00000000..190b96fa --- /dev/null +++ b/llama-cpp-bindings/src/error/chat_template_error.rs @@ -0,0 +1,17 @@ +use std::ffi::NulError; + +/// There was an error while getting the chat template from a model. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum ChatTemplateError { + /// gguf has no chat template (by that name) + #[error("chat template not found - returned null pointer")] + MissingTemplate, + + /// chat template contained a null byte + #[error("null byte in string {0}")] + NullError(#[from] NulError), + + /// The chat template was not valid utf8. + #[error(transparent)] + Utf8Error(#[from] std::str::Utf8Error), +} diff --git a/llama-cpp-bindings/src/error/decode_error.rs b/llama-cpp-bindings/src/error/decode_error.rs new file mode 100644 index 00000000..1a404605 --- /dev/null +++ b/llama-cpp-bindings/src/error/decode_error.rs @@ -0,0 +1,70 @@ +use std::num::NonZeroI32; +use std::os::raw::c_int; + +/// Failed to decode a batch. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum DecodeError { + /// No kv cache slot was available. + #[error("Decode Error 1: NoKvCacheSlot")] + NoKvCacheSlot, + /// The computation was aborted by the abort callback. + #[error("Decode Error 2: Aborted")] + Aborted, + /// The number of tokens in the batch was 0. + #[error("Decode Error -1: n_tokens == 0")] + NTokensZero, + /// An unknown error occurred. + #[error("Decode Error {0}: unknown")] + Unknown(c_int), +} + +/// Decode a error from llama.cpp into a [`DecodeError`]. +impl From for DecodeError { + fn from(value: NonZeroI32) -> Self { + match value.get() { + 1 => Self::NoKvCacheSlot, + 2 => Self::Aborted, + -1 => Self::NTokensZero, + error_code => Self::Unknown(error_code), + } + } +} + +#[cfg(test)] +mod tests { + use std::num::NonZeroI32; + + use super::DecodeError; + + #[test] + fn decode_error_no_kv_cache_slot() { + let error = DecodeError::from(NonZeroI32::new(1).expect("1 is non-zero")); + + assert_eq!(error, DecodeError::NoKvCacheSlot); + assert_eq!(error.to_string(), "Decode Error 1: NoKvCacheSlot"); + } + + #[test] + fn decode_error_n_tokens_zero() { + let error = DecodeError::from(NonZeroI32::new(-1).expect("-1 is non-zero")); + + assert_eq!(error, DecodeError::NTokensZero); + assert_eq!(error.to_string(), "Decode Error -1: n_tokens == 0"); + } + + #[test] + fn decode_error_aborted() { + let error = DecodeError::from(NonZeroI32::new(2).expect("2 is non-zero")); + + assert_eq!(error, DecodeError::Aborted); + assert_eq!(error.to_string(), "Decode Error 2: Aborted"); + } + + #[test] + fn decode_error_unknown() { + let error = DecodeError::from(NonZeroI32::new(42).expect("42 is non-zero")); + + assert_eq!(error, DecodeError::Unknown(42)); + assert_eq!(error.to_string(), "Decode Error 42: unknown"); + } +} diff --git a/llama-cpp-bindings/src/error/embeddings_error.rs b/llama-cpp-bindings/src/error/embeddings_error.rs new file mode 100644 index 00000000..a01bb428 --- /dev/null +++ b/llama-cpp-bindings/src/error/embeddings_error.rs @@ -0,0 +1,16 @@ +/// When embedding related functions fail +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum EmbeddingsError { + /// Embeddings weren't enabled in the context options + #[error("Embeddings weren't enabled in the context options")] + NotEnabled, + /// Logits weren't enabled for the given token + #[error("Logits were not enabled for the given token")] + LogitsNotEnabled, + /// The given sequence index exceeds the max sequence id + #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")] + NonePoolType, + /// The embedding dimension does not fit into a usize. + #[error("Invalid embedding dimension: {0}")] + InvalidEmbeddingDimension(#[source] std::num::TryFromIntError), +} diff --git a/llama-cpp-bindings/src/error/encode_error.rs b/llama-cpp-bindings/src/error/encode_error.rs new file mode 100644 index 00000000..33999d61 --- /dev/null +++ b/llama-cpp-bindings/src/error/encode_error.rs @@ -0,0 +1,58 @@ +use std::num::NonZeroI32; +use std::os::raw::c_int; + +/// Failed to decode a batch. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum EncodeError { + /// No kv cache slot was available. + #[error("Encode Error 1: NoKvCacheSlot")] + NoKvCacheSlot, + /// The number of tokens in the batch was 0. + #[error("Encode Error -1: n_tokens == 0")] + NTokensZero, + /// An unknown error occurred. + #[error("Encode Error {0}: unknown")] + Unknown(c_int), +} + +/// Encode a error from llama.cpp into a [`EncodeError`]. +impl From for EncodeError { + fn from(value: NonZeroI32) -> Self { + match value.get() { + 1 => Self::NoKvCacheSlot, + -1 => Self::NTokensZero, + error_code => Self::Unknown(error_code), + } + } +} + +#[cfg(test)] +mod tests { + use std::num::NonZeroI32; + + use super::EncodeError; + + #[test] + fn encode_error_no_kv_cache_slot() { + let error = EncodeError::from(NonZeroI32::new(1).expect("1 is non-zero")); + + assert_eq!(error, EncodeError::NoKvCacheSlot); + assert_eq!(error.to_string(), "Encode Error 1: NoKvCacheSlot"); + } + + #[test] + fn encode_error_n_tokens_zero() { + let error = EncodeError::from(NonZeroI32::new(-1).expect("-1 is non-zero")); + + assert_eq!(error, EncodeError::NTokensZero); + assert_eq!(error.to_string(), "Encode Error -1: n_tokens == 0"); + } + + #[test] + fn encode_error_unknown() { + let error = EncodeError::from(NonZeroI32::new(99).expect("99 is non-zero")); + + assert_eq!(error, EncodeError::Unknown(99)); + assert_eq!(error.to_string(), "Encode Error 99: unknown"); + } +} diff --git a/llama-cpp-bindings/src/error/eval_multimodal_chunks_error.rs b/llama-cpp-bindings/src/error/eval_multimodal_chunks_error.rs new file mode 100644 index 00000000..146bcedb --- /dev/null +++ b/llama-cpp-bindings/src/error/eval_multimodal_chunks_error.rs @@ -0,0 +1,16 @@ +use crate::mtmd::MtmdEvalError; +use crate::mtmd::mtmd_input_chunk_type_error::MtmdInputChunkTypeError; + +/// Failed to evaluate multimodal chunks through the request classifier. +#[derive(Debug, thiserror::Error)] +pub enum EvalMultimodalChunksError { + /// `MtmdInputChunks::eval_chunks` returned an error. + #[error("{0}")] + EvalFailed(#[from] MtmdEvalError), + /// A chunk reported a type that is not known to this binding. + #[error("{0}")] + UnknownChunkType(#[from] MtmdInputChunkTypeError), + /// A chunk index that was within `chunks.len()` returned `None` from `chunks.get(index)`. + #[error("chunk index {0} out of bounds during post-eval walk")] + ChunkOutOfBounds(usize), +} diff --git a/llama-cpp-bindings/src/error/fit_error.rs b/llama-cpp-bindings/src/error/fit_error.rs new file mode 100644 index 00000000..7585530d --- /dev/null +++ b/llama-cpp-bindings/src/error/fit_error.rs @@ -0,0 +1,11 @@ +/// Returned by [`crate::model::params::LlamaModelParams::fit_params`]. +#[derive(Debug, Clone, Copy, Eq, PartialEq, thiserror::Error)] +pub enum FitError { + /// Could not find allocations that fit available memory. + #[error("could not find allocations that fit available memory")] + Failure, + /// A hard error occurred during fitting (e.g. model not found at the specified path, + /// or the C++ wrapper threw an exception). + #[error("hard error during parameter fitting")] + Error, +} diff --git a/llama-cpp-bindings/src/error/grammar_error.rs b/llama-cpp-bindings/src/error/grammar_error.rs new file mode 100644 index 00000000..58216b8c --- /dev/null +++ b/llama-cpp-bindings/src/error/grammar_error.rs @@ -0,0 +1,27 @@ +use std::ffi::NulError; + +/// Errors that can occur when initializing a grammar sampler +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum GrammarError { + /// The grammar root was not found in the grammar string + #[error("Grammar root not found in grammar string")] + RootNotFound, + /// The trigger word contains null bytes + #[error("Trigger word contains null bytes: {0}")] + TriggerWordNullBytes(NulError), + /// The grammar string or root contains null bytes + #[error("Grammar string or root contains null bytes: {0}")] + GrammarNullBytes(NulError), + /// A string contains null bytes + #[error("String contains null bytes: {0}")] + NulError(#[from] NulError), + /// The grammar call returned null + #[error("Grammar initialization failed: {0}")] + NullGrammar(String), + /// An integer value exceeded the allowed range + #[error("Integer overflow: {0}")] + IntegerOverflow(String), + /// An error from the llguidance library + #[error("llguidance error: {0}")] + LlguidanceError(String), +} diff --git a/llama-cpp-bindings/src/error/json_object_failure.rs b/llama-cpp-bindings/src/error/json_object_failure.rs new file mode 100644 index 00000000..b5d88570 --- /dev/null +++ b/llama-cpp-bindings/src/error/json_object_failure.rs @@ -0,0 +1,6 @@ +/// Failures specific to the JSON-object args parser (Qwen 3 `{"name":..., "arguments":...}`). +#[derive(Debug, thiserror::Error)] +pub enum JsonObjectFailure { + #[error("tool call body has malformed JSON: {message}")] + InvalidJson { message: String }, +} diff --git a/llama-cpp-bindings/src/error/key_value_xml_tags_failure.rs b/llama-cpp-bindings/src/error/key_value_xml_tags_failure.rs new file mode 100644 index 00000000..3c46093a --- /dev/null +++ b/llama-cpp-bindings/src/error/key_value_xml_tags_failure.rs @@ -0,0 +1,31 @@ +/// Failures specific to the key-value XML-tags parser (GLM-4.7 `{name}{k}{v}...`). +#[derive(Debug, thiserror::Error)] +pub enum KeyValueXmlTagsFailure { + #[error("tool call function tag has empty name")] + EmptyFunctionName, + #[error("tool call function block is missing close tag '{expected_close}'")] + UnclosedFunctionBlock { expected_close: String }, + #[error("tool call function '{function_name}' has key tag with empty content")] + EmptyKey { function_name: String }, + #[error("tool call function '{function_name}' is missing key close tag '{expected_close}'")] + UnclosedKeyTag { + function_name: String, + expected_close: String, + }, + #[error( + "tool call function '{function_name}' key '{key}' is missing value open tag '{expected_open}'" + )] + MissingValueTag { + function_name: String, + key: String, + expected_open: String, + }, + #[error( + "tool call function '{function_name}' key '{key}' is missing value close tag '{expected_close}'" + )] + UnclosedValueTag { + function_name: String, + key: String, + expected_close: String, + }, +} diff --git a/llama-cpp-bindings/src/error/llama_context_load_error.rs b/llama-cpp-bindings/src/error/llama_context_load_error.rs new file mode 100644 index 00000000..752c88af --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_context_load_error.rs @@ -0,0 +1,7 @@ +/// Failed to Load context +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaContextLoadError { + /// llama.cpp returned null + #[error("null reference from llama.cpp")] + NullReturn, +} diff --git a/llama-cpp-bindings/src/error/llama_cpp_error.rs b/llama-cpp-bindings/src/error/llama_cpp_error.rs new file mode 100644 index 00000000..b99fefdd --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_cpp_error.rs @@ -0,0 +1,50 @@ +use crate::batch_add_error::BatchAddError; +use crate::error::chat_template_error::ChatTemplateError; +use crate::error::decode_error::DecodeError; +use crate::error::embeddings_error::EmbeddingsError; +use crate::error::encode_error::EncodeError; +use crate::error::fit_error::FitError; +use crate::error::llama_context_load_error::LlamaContextLoadError; +use crate::error::llama_model_load_error::LlamaModelLoadError; + +/// All errors that can occur in the llama-cpp crate. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaCppError { + /// The backend was already initialized. This can generally be ignored as initializing the backend + /// is idempotent. + #[error("BackendAlreadyInitialized")] + BackendAlreadyInitialized, + /// There was an error while get the chat template from model. + #[error("{0}")] + ChatTemplateError(#[from] ChatTemplateError), + /// There was an error while decoding a batch. + #[error("{0}")] + DecodeError(#[from] DecodeError), + /// There was an error while encoding a batch. + #[error("{0}")] + EncodeError(#[from] EncodeError), + /// There was an error loading a model. + #[error("{0}")] + LlamaModelLoadError(#[from] LlamaModelLoadError), + /// There was an error creating a new model context. + #[error("{0}")] + LlamaContextLoadError(#[from] LlamaContextLoadError), + /// There was an error adding a token to a batch. + #[error["{0}"]] + BatchAddError(#[from] BatchAddError), + /// see [`EmbeddingsError`] + #[error(transparent)] + EmbeddingError(#[from] EmbeddingsError), + /// Backend device not found + #[error("Backend device {0} not found")] + BackendDeviceNotFound(usize), + /// Max devices exceeded + #[error("Max devices exceeded. Max devices is {0}")] + MaxDevicesExceeded(usize), + /// Failed to convert JSON schema to grammar. + #[error("JsonSchemaToGrammarError: {0}")] + JsonSchemaToGrammarError(String), + /// see [`FitError`] + #[error(transparent)] + FitError(#[from] FitError), +} diff --git a/llama-cpp-bindings/src/error/llama_lora_adapter_init_error.rs b/llama-cpp-bindings/src/error/llama_lora_adapter_init_error.rs new file mode 100644 index 00000000..9a294994 --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_lora_adapter_init_error.rs @@ -0,0 +1,19 @@ +use std::ffi::NulError; +use std::path::PathBuf; + +/// An error that can occur when loading a model. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaLoraAdapterInitError { + /// There was a null byte in a provided string and thus it could not be converted to a C string. + #[error("null byte in string {0}")] + NullError(#[from] NulError), + /// llama.cpp returned a nullptr - this could be many different causes. + #[error("null result from llama cpp")] + NullResult, + /// Failed to convert the path to a rust str. This means the path was not valid unicode + #[error("failed to convert path {0} to str")] + PathToStrError(PathBuf), + /// The adapter file does not exist at the given path. + #[error("adapter file not found: {0}")] + FileNotFound(PathBuf), +} diff --git a/llama-cpp-bindings/src/error/llama_lora_adapter_remove_error.rs b/llama-cpp-bindings/src/error/llama_lora_adapter_remove_error.rs new file mode 100644 index 00000000..3d536c4a --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_lora_adapter_remove_error.rs @@ -0,0 +1,7 @@ +/// An error that can occur when loading a model. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaLoraAdapterRemoveError { + /// llama.cpp returned a non-zero error code. + #[error("error code from llama cpp")] + ErrorResult(i32), +} diff --git a/llama-cpp-bindings/src/error/llama_lora_adapter_set_error.rs b/llama-cpp-bindings/src/error/llama_lora_adapter_set_error.rs new file mode 100644 index 00000000..362f6ca1 --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_lora_adapter_set_error.rs @@ -0,0 +1,7 @@ +/// An error that can occur when loading a model. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaLoraAdapterSetError { + /// llama.cpp returned a non-zero error code. + #[error("error code from llama cpp")] + ErrorResult(i32), +} diff --git a/llama-cpp-bindings/src/error/llama_model_load_error.rs b/llama-cpp-bindings/src/error/llama_model_load_error.rs new file mode 100644 index 00000000..a7b24012 --- /dev/null +++ b/llama-cpp-bindings/src/error/llama_model_load_error.rs @@ -0,0 +1,19 @@ +use std::ffi::NulError; +use std::path::PathBuf; + +/// An error that can occur when loading a model. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaModelLoadError { + /// There was a null byte in a provided string and thus it could not be converted to a C string. + #[error("null byte in string {0}")] + NullError(#[from] NulError), + /// llama.cpp returned a nullptr - this could be many different causes. + #[error("null result from llama cpp")] + NullResult, + /// Failed to convert the path to a rust str. This means the path was not valid unicode + #[error("failed to convert path {0} to str")] + PathToStrError(PathBuf), + /// The model file does not exist at the given path. + #[error("model file not found: {0}")] + FileNotFound(PathBuf), +} diff --git a/llama-cpp-bindings/src/error/logits_error.rs b/llama-cpp-bindings/src/error/logits_error.rs new file mode 100644 index 00000000..f6a198d2 --- /dev/null +++ b/llama-cpp-bindings/src/error/logits_error.rs @@ -0,0 +1,24 @@ +/// When logits-related functions fail +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LogitsError { + /// The logits data pointer is null. + #[error("logits data pointer is null")] + NullLogits, + /// The requested token index has not been initialized for logits. + #[error("logit for token index {0} is not initialized")] + TokenNotInitialized(i32), + /// The token index exceeds the context size. + #[error("token index {token_index} exceeds context size {context_size}")] + TokenIndexExceedsContext { + /// The token index that was requested. + token_index: u32, + /// The context size. + context_size: u32, + }, + /// The vocabulary size does not fit into a usize. + #[error("n_vocab does not fit into usize: {0}")] + VocabSizeOverflow(#[source] std::num::TryFromIntError), + /// The token index does not fit into a u32. + #[error("token_index does not fit into u32: {0}")] + TokenIndexOverflow(#[source] std::num::TryFromIntError), +} diff --git a/llama-cpp-bindings/src/error/marker_detection_error.rs b/llama-cpp-bindings/src/error/marker_detection_error.rs new file mode 100644 index 00000000..aa755878 --- /dev/null +++ b/llama-cpp-bindings/src/error/marker_detection_error.rs @@ -0,0 +1,15 @@ +use std::string::FromUtf8Error; + +/// Failed to detect tool-call diagnostic markers for a model. +#[derive(Debug, thiserror::Error)] +pub enum MarkerDetectionError { + /// llama.cpp returned an error code from the marker detection FFI call. + #[error("ffi error {0}")] + FfiError(i32), + /// The C++ side threw an exception during template analysis. + #[error("c++ exception during template analysis: {0}")] + AnalyzeException(String), + /// llama.cpp returned a marker string but its bytes were not valid UTF-8. + #[error("ffi returned non-utf8 marker bytes: {0}")] + MarkerUtf8Error(#[from] FromUtf8Error), +} diff --git a/llama-cpp-bindings/src/error/meta_val_error.rs b/llama-cpp-bindings/src/error/meta_val_error.rs new file mode 100644 index 00000000..30b07223 --- /dev/null +++ b/llama-cpp-bindings/src/error/meta_val_error.rs @@ -0,0 +1,18 @@ +use std::ffi::NulError; +use std::string::FromUtf8Error; + +/// Failed fetching metadata value +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum MetaValError { + /// The provided string contains an unexpected null-byte + #[error("null byte in string {0}")] + NullError(#[from] NulError), + + /// The returned data contains invalid UTF8 data + #[error("FromUtf8Error {0}")] + FromUtf8Error(#[from] FromUtf8Error), + + /// Got negative return value. This happens if the key or index queried does not exist. + #[error("Negative return value. Likely due to a missing index or key. Got return value: {0}")] + NegativeReturn(i32), +} diff --git a/llama-cpp-bindings/src/error/model_params_error.rs b/llama-cpp-bindings/src/error/model_params_error.rs new file mode 100644 index 00000000..377596f1 --- /dev/null +++ b/llama-cpp-bindings/src/error/model_params_error.rs @@ -0,0 +1,18 @@ +/// Errors that can occur when modifying model parameters. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum ModelParamsError { + /// The internal override vector has no available slot. + #[error("No available slot in override vector")] + NoAvailableSlot, + /// The first override slot is not empty. + #[error("Override slot is not empty")] + SlotNotEmpty, + /// A character in the key is not a valid C char. + #[error("Invalid character in key: byte {byte}, {reason}")] + InvalidCharacterInKey { + /// The byte value that failed conversion. + byte: u8, + /// The reason the conversion failed. + reason: String, + }, +} diff --git a/llama-cpp-bindings/src/error/new_llama_chat_message_error.rs b/llama-cpp-bindings/src/error/new_llama_chat_message_error.rs new file mode 100644 index 00000000..c7076486 --- /dev/null +++ b/llama-cpp-bindings/src/error/new_llama_chat_message_error.rs @@ -0,0 +1,9 @@ +use std::ffi::NulError; + +/// Failed to apply model chat template. +#[derive(Debug, thiserror::Error)] +pub enum NewLlamaChatMessageError { + /// the string contained a null byte and thus could not be converted to a c string. + #[error("{0}")] + NulError(#[from] NulError), +} diff --git a/llama-cpp-bindings/src/error/paired_quote_failure.rs b/llama-cpp-bindings/src/error/paired_quote_failure.rs new file mode 100644 index 00000000..9a2a3d85 --- /dev/null +++ b/llama-cpp-bindings/src/error/paired_quote_failure.rs @@ -0,0 +1,23 @@ +/// Failures specific to the paired-quote args parser (Gemma 4 `<|tool_call>call:name{key:<|"|>val<|"|>}`). +#[derive(Debug, thiserror::Error)] +pub enum PairedQuoteFailure { + #[error("empty key in tool call '{tool_name}' arguments")] + EmptyKey { tool_name: String }, + #[error("tool call '{tool_name}' translated arguments are not valid JSON: {message}")] + InvalidJsonArguments { tool_name: String, message: String }, + #[error("tool call '{tool_name}' has unclosed quoted value for key '{key}'")] + UnclosedQuotedValue { tool_name: String, key: String }, + #[error("tool call '{tool_name}' arguments ended without close marker (state: {state})")] + UnclosedArgumentBlock { + tool_name: String, + state: &'static str, + }, + #[error( + "tool call '{tool_name}' has unexpected character '{character}' after value for key '{key}'" + )] + UnexpectedCharAfterValue { + tool_name: String, + key: String, + character: char, + }, +} diff --git a/llama-cpp-bindings/src/error/parse_chat_message_error.rs b/llama-cpp-bindings/src/error/parse_chat_message_error.rs new file mode 100644 index 00000000..75460ed4 --- /dev/null +++ b/llama-cpp-bindings/src/error/parse_chat_message_error.rs @@ -0,0 +1,32 @@ +use std::string::FromUtf8Error; + +use crate::error::tool_call_format_failure::ToolCallFormatFailure; + +/// Failed to parse a chat message via [`crate::Model::parse_chat_message`]. +#[derive(Debug, thiserror::Error)] +pub enum ParseChatMessageError { + /// llama.cpp returned an error code from the parse FFI call. + #[error("ffi error {0}")] + FfiError(i32), + /// The C++ side threw an exception while parsing. + #[error("c++ exception during chat parse: {0}")] + ParseException(String), + /// An accessor returned bytes that were not valid UTF-8. + #[error("ffi returned non-utf8 string: {0}")] + StringUtf8Error(#[from] FromUtf8Error), + /// The caller passed a `tools_json` argument that is not valid JSON. + #[error("tools_json is not valid JSON: {0}")] + ToolsJsonInvalid(#[source] serde_json::Error), + /// The caller passed a `tools_json` argument that parses as JSON but is not an array. + #[error("tools_json must be a JSON array")] + ToolsJsonNotArray, + /// Failed to serialize the tools array for the FFI call. + #[error("could not serialize tools to JSON: {0}")] + ToolsSerialization(String), + /// The model has no usable chat template, so the parser cannot be built. + #[error("model has no chat template")] + NoChatTemplate, + /// The wrapper-side fallback parser detected a structural issue while parsing the body. + #[error("template-override fallback parser failed: {0}")] + TemplateOverrideFailed(#[from] ToolCallFormatFailure), +} diff --git a/llama-cpp-bindings/src/error/sample_error.rs b/llama-cpp-bindings/src/error/sample_error.rs new file mode 100644 index 00000000..a7bbf4e8 --- /dev/null +++ b/llama-cpp-bindings/src/error/sample_error.rs @@ -0,0 +1,11 @@ +/// Errors that can occur when sampling a token. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum SampleError { + /// A C++ exception was thrown during sampling + #[error("C++ exception during sampling: {0}")] + CppException(String), + + /// An invalid argument was passed to the sampler + #[error("Invalid argument passed to sampler")] + InvalidArgument, +} diff --git a/llama-cpp-bindings/src/error/sampler_accept_error.rs b/llama-cpp-bindings/src/error/sampler_accept_error.rs new file mode 100644 index 00000000..afa32a61 --- /dev/null +++ b/llama-cpp-bindings/src/error/sampler_accept_error.rs @@ -0,0 +1,11 @@ +/// Failed to accept a token in a sampler. +#[derive(Debug, thiserror::Error)] +pub enum SamplerAcceptError { + /// A C++ exception was thrown during accept + #[error("C++ exception during sampler accept: {0}")] + CppException(String), + + /// An invalid argument was passed (null sampler or null error pointer) + #[error("Invalid argument passed to sampler accept")] + InvalidArgument, +} diff --git a/llama-cpp-bindings/src/error/sampling_error.rs b/llama-cpp-bindings/src/error/sampling_error.rs new file mode 100644 index 00000000..7a2e7346 --- /dev/null +++ b/llama-cpp-bindings/src/error/sampling_error.rs @@ -0,0 +1,7 @@ +/// Errors that can occur when creating a sampling configuration. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum SamplingError { + /// An integer value exceeded the allowed range + #[error("Integer overflow: {0}")] + IntegerOverflow(String), +} diff --git a/llama-cpp-bindings/src/error/string_to_token_error.rs b/llama-cpp-bindings/src/error/string_to_token_error.rs new file mode 100644 index 00000000..dc00b484 --- /dev/null +++ b/llama-cpp-bindings/src/error/string_to_token_error.rs @@ -0,0 +1,12 @@ +use std::ffi::NulError; + +/// Failed to convert a string to a token sequence. +#[derive(Debug, thiserror::Error)] +pub enum StringToTokenError { + /// the string contained a null byte and thus could not be converted to a c string. + #[error("{0}")] + NulError(#[from] NulError), + #[error("{0}")] + /// Failed to convert a provided integer to a [`c_int`]. + CIntConversionError(#[from] std::num::TryFromIntError), +} diff --git a/llama-cpp-bindings/src/error/token_sampling_error.rs b/llama-cpp-bindings/src/error/token_sampling_error.rs new file mode 100644 index 00000000..da1bc7f0 --- /dev/null +++ b/llama-cpp-bindings/src/error/token_sampling_error.rs @@ -0,0 +1,7 @@ +/// Failed to sample a token from the data array. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum TokenSamplingError { + /// The sampler did not select any token. + #[error("No token was selected by the sampler")] + NoTokenSelected, +} diff --git a/llama-cpp-bindings/src/error/token_to_string_error.rs b/llama-cpp-bindings/src/error/token_to_string_error.rs new file mode 100644 index 00000000..0fb0eb89 --- /dev/null +++ b/llama-cpp-bindings/src/error/token_to_string_error.rs @@ -0,0 +1,20 @@ +use std::os::raw::c_int; +use std::string::FromUtf8Error; + +/// An error that can occur when converting a token to a string. +#[derive(Debug, thiserror::Error, Clone)] +#[non_exhaustive] +pub enum TokenToStringError { + /// the token type was unknown + #[error("Unknown Token Type")] + UnknownTokenType, + /// There was insufficient buffer space to convert the token to a string. + #[error("Insufficient Buffer Space {0}")] + InsufficientBufferSpace(c_int), + /// The token was not valid utf8. + #[error("FromUtf8Error {0}")] + FromUtf8Error(#[from] FromUtf8Error), + /// An integer conversion failed. + #[error("Integer conversion error: {0}")] + IntConversionError(#[from] std::num::TryFromIntError), +} diff --git a/llama-cpp-bindings/src/error/tool_call_format_failure.rs b/llama-cpp-bindings/src/error/tool_call_format_failure.rs new file mode 100644 index 00000000..ca1bd3d7 --- /dev/null +++ b/llama-cpp-bindings/src/error/tool_call_format_failure.rs @@ -0,0 +1,20 @@ +use crate::error::bracketed_args_failure::BracketedArgsFailure; +use crate::error::json_object_failure::JsonObjectFailure; +use crate::error::key_value_xml_tags_failure::KeyValueXmlTagsFailure; +use crate::error::paired_quote_failure::PairedQuoteFailure; +use crate::error::xml_function_tags_failure::XmlFunctionTagsFailure; + +/// Top-level failure for the wrapper-side template-override parsers (one variant per supported shape). +#[derive(Debug, thiserror::Error)] +pub enum ToolCallFormatFailure { + #[error("bracketed-args fallback parser: {0}")] + BracketedArgs(#[from] BracketedArgsFailure), + #[error("json-object fallback parser: {0}")] + JsonObject(#[from] JsonObjectFailure), + #[error("key-value-xml-tags fallback parser: {0}")] + KeyValueXmlTags(#[from] KeyValueXmlTagsFailure), + #[error("paired-quote fallback parser: {0}")] + PairedQuote(#[from] PairedQuoteFailure), + #[error("xml-function-tags fallback parser: {0}")] + XmlFunctionTags(#[from] XmlFunctionTagsFailure), +} diff --git a/llama-cpp-bindings/src/error/xml_function_tags_failure.rs b/llama-cpp-bindings/src/error/xml_function_tags_failure.rs new file mode 100644 index 00000000..49180c00 --- /dev/null +++ b/llama-cpp-bindings/src/error/xml_function_tags_failure.rs @@ -0,0 +1,21 @@ +/// Failures specific to the XML function-tags parser (Qwen 3.5+ `val`). +#[derive(Debug, thiserror::Error)] +pub enum XmlFunctionTagsFailure { + #[error("tool call function tag has empty name")] + EmptyFunctionName, + #[error("tool call function '{function_name}' is missing close tag '{expected_close}'")] + UnclosedFunctionBlock { + function_name: String, + expected_close: String, + }, + #[error("tool call function '{function_name}' has parameter with empty name")] + EmptyParameterName { function_name: String }, + #[error( + "tool call function '{function_name}' parameter '{parameter_name}' is missing close tag '{expected_close}'" + )] + UnclosedParameterBlock { + function_name: String, + parameter_name: String, + expected_close: String, + }, +} diff --git a/llama-cpp-bindings/src/ingest_outcome.rs b/llama-cpp-bindings/src/ingest_outcome.rs new file mode 100644 index 00000000..abf3a44b --- /dev/null +++ b/llama-cpp-bindings/src/ingest_outcome.rs @@ -0,0 +1,14 @@ +use crate::sampled_token::SampledToken; + +#[derive(Clone, Debug)] +pub struct IngestOutcome { + pub sampled_token: SampledToken, + /// Empty when the token is part of a recognised marker boundary; otherwise + /// the decoded UTF-8 piece. Callers should stream `visible_piece` and skip + /// emission when it is empty. + pub visible_piece: String, + /// Always the decoded UTF-8 piece, even for marker-boundary tokens. Useful + /// for accumulating the full raw model output (e.g. for downstream parser + /// cross-checks) without losing marker bytes. + pub raw_piece: String, +} diff --git a/llama-cpp-bindings/src/invalid_numa_strategy.rs b/llama-cpp-bindings/src/invalid_numa_strategy.rs new file mode 100644 index 00000000..2d00b029 --- /dev/null +++ b/llama-cpp-bindings/src/invalid_numa_strategy.rs @@ -0,0 +1,6 @@ +/// An invalid numa strategy was provided. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub struct InvalidNumaStrategy( + /// The invalid numa strategy that was provided. + pub llama_cpp_bindings_sys::ggml_numa_strategy, +); diff --git a/llama-cpp-bindings/src/lib.rs b/llama-cpp-bindings/src/lib.rs index 4ee62c7e..b77d14a4 100644 --- a/llama-cpp-bindings/src/lib.rs +++ b/llama-cpp-bindings/src/lib.rs @@ -8,7 +8,6 @@ //! # Feature Flags //! //! - `cuda` enables CUDA gpu support. -//! - `sampler` adds the [`context::sample::sampler`] struct for a more rusty way of sampling. pub mod batch_add_error; pub mod chat_message_parse_outcome; @@ -22,10 +21,13 @@ pub mod ggml_time_us; pub mod gguf_context; pub mod gguf_context_error; pub mod gguf_type; +pub mod ingest_outcome; pub mod ingest_prompt_chunk; +pub mod invalid_numa_strategy; pub mod json_schema_to_grammar; pub mod llama_backend; pub mod llama_backend_device; +pub mod llama_backend_device_type; pub mod llama_backend_numa_strategy; pub mod llama_batch; pub mod llama_time_us; @@ -39,7 +41,6 @@ pub mod load_backends; pub mod load_backends_error; #[cfg(feature = "dynamic-backends")] pub mod load_backends_from_path; -pub mod log; pub mod log_options; pub mod max_devices; pub mod mlock_supported; @@ -50,8 +51,11 @@ pub mod raw_chat_message; pub mod resolved_tool_call_markers; pub mod sampled_token; pub mod sampled_token_classifier; +pub mod sampled_token_section; pub mod sampling; +pub mod send_logs_to_log; pub mod streaming_json_probe; +pub mod streaming_markers; pub mod timing; pub mod token; pub mod tool_call_format; @@ -68,9 +72,8 @@ pub use error::{ }; pub use chat_message_parse_outcome::ChatMessageParseOutcome; -pub use llama_backend_device::{ - LlamaBackendDevice, LlamaBackendDeviceType, list_llama_ggml_backend_devices, -}; +pub use llama_backend_device::{LlamaBackendDevice, list_llama_ggml_backend_devices}; +pub use llama_backend_device_type::LlamaBackendDeviceType; pub use llama_cpp_bindings_types::{ BracketedJsonShape, KeyValueXmlTagsShape, PairedQuoteShape, ParsedChatMessage, ParsedToolCall, ReasoningMarkers, TokenUsage, TokenUsageError, ToolCallArgsShape, ToolCallArguments, @@ -79,7 +82,7 @@ pub use llama_cpp_bindings_types::{ pub use raw_chat_message::RawChatMessage; pub use sampled_token::SampledToken; pub use sampled_token_classifier::SampledTokenClassifier; -pub use sampled_token_classifier::SampledTokenSection; +pub use sampled_token_section::SampledTokenSection; pub use ffi_status_is_ok::status_is_ok; pub use ffi_status_to_i32::status_to_i32; @@ -91,5 +94,5 @@ pub use max_devices::max_devices; pub use mlock_supported::mlock_supported; pub use mmap_supported::mmap_supported; -pub use log::send_logs_to_tracing; pub use log_options::LogOptions; +pub use send_logs_to_log::send_logs_to_log; diff --git a/llama-cpp-bindings/src/llama_backend.rs b/llama-cpp-bindings/src/llama_backend.rs index 803c27a2..20ad3ac3 100644 --- a/llama-cpp-bindings/src/llama_backend.rs +++ b/llama-cpp-bindings/src/llama_backend.rs @@ -45,7 +45,6 @@ impl LlamaBackend { /// ``` /// # Errors /// Returns an error if the backend was already initialized. - #[tracing::instrument(skip_all)] pub fn init() -> crate::Result { Self::mark_init()?; unsafe { llama_cpp_bindings_sys::llama_backend_init() } @@ -67,7 +66,6 @@ impl LlamaBackend { /// ``` /// # Errors /// Returns an error if the backend was already initialized. - #[tracing::instrument(skip_all)] pub fn init_numa(strategy: NumaStrategy) -> crate::Result { Self::mark_init()?; unsafe { diff --git a/llama-cpp-bindings/src/llama_backend_device.rs b/llama-cpp-bindings/src/llama_backend_device.rs index c65d2f99..b5851efb 100644 --- a/llama-cpp-bindings/src/llama_backend_device.rs +++ b/llama-cpp-bindings/src/llama_backend_device.rs @@ -1,35 +1,8 @@ use std::ffi::c_char; -/// Backend device type -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum LlamaBackendDeviceType { - /// CPU device - Cpu, - /// ACCEL device - Accelerator, - /// GPU device - Gpu, - /// iGPU device - IntegratedGpu, - /// Unknown device type - Unknown, -} +use crate::llama_backend_device_type::device_type_from_raw; -const fn device_type_from_raw( - raw_type: llama_cpp_bindings_sys::ggml_backend_dev_type, -) -> LlamaBackendDeviceType { - match raw_type { - llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_CPU => LlamaBackendDeviceType::Cpu, - llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_ACCEL => { - LlamaBackendDeviceType::Accelerator - } - llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_GPU => LlamaBackendDeviceType::Gpu, - llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_IGPU => { - LlamaBackendDeviceType::IntegratedGpu - } - _ => LlamaBackendDeviceType::Unknown, - } -} +pub use crate::llama_backend_device_type::LlamaBackendDeviceType; /// A ggml backend device /// @@ -127,28 +100,4 @@ mod tests { assert_eq!(devices[0].index, 0); assert!(!devices[0].name.is_empty()); } - - #[test] - fn device_type_from_raw_all_variants() { - use super::LlamaBackendDeviceType; - use super::device_type_from_raw; - - assert_eq!( - device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_CPU), - LlamaBackendDeviceType::Cpu - ); - assert_eq!( - device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_ACCEL), - LlamaBackendDeviceType::Accelerator - ); - assert_eq!( - device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_GPU), - LlamaBackendDeviceType::Gpu - ); - assert_eq!( - device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_IGPU), - LlamaBackendDeviceType::IntegratedGpu - ); - assert_eq!(device_type_from_raw(9999), LlamaBackendDeviceType::Unknown); - } } diff --git a/llama-cpp-bindings/src/llama_backend_device_type.rs b/llama-cpp-bindings/src/llama_backend_device_type.rs new file mode 100644 index 00000000..fd22c8fd --- /dev/null +++ b/llama-cpp-bindings/src/llama_backend_device_type.rs @@ -0,0 +1,58 @@ +/// Backend device type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LlamaBackendDeviceType { + /// CPU device + Cpu, + /// ACCEL device + Accelerator, + /// GPU device + Gpu, + /// iGPU device + IntegratedGpu, + /// Unknown device type + Unknown, +} + +#[must_use] +pub const fn device_type_from_raw( + raw_type: llama_cpp_bindings_sys::ggml_backend_dev_type, +) -> LlamaBackendDeviceType { + match raw_type { + llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_CPU => LlamaBackendDeviceType::Cpu, + llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_ACCEL => { + LlamaBackendDeviceType::Accelerator + } + llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_GPU => LlamaBackendDeviceType::Gpu, + llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_IGPU => { + LlamaBackendDeviceType::IntegratedGpu + } + _ => LlamaBackendDeviceType::Unknown, + } +} + +#[cfg(test)] +mod tests { + use super::LlamaBackendDeviceType; + use super::device_type_from_raw; + + #[test] + fn device_type_from_raw_all_variants() { + assert_eq!( + device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_CPU), + LlamaBackendDeviceType::Cpu + ); + assert_eq!( + device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_ACCEL), + LlamaBackendDeviceType::Accelerator + ); + assert_eq!( + device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_GPU), + LlamaBackendDeviceType::Gpu + ); + assert_eq!( + device_type_from_raw(llama_cpp_bindings_sys::GGML_BACKEND_DEVICE_TYPE_IGPU), + LlamaBackendDeviceType::IntegratedGpu + ); + assert_eq!(device_type_from_raw(9999), LlamaBackendDeviceType::Unknown); + } +} diff --git a/llama-cpp-bindings/src/llama_backend_numa_strategy.rs b/llama-cpp-bindings/src/llama_backend_numa_strategy.rs index 6c555123..2be150fa 100644 --- a/llama-cpp-bindings/src/llama_backend_numa_strategy.rs +++ b/llama-cpp-bindings/src/llama_backend_numa_strategy.rs @@ -1,3 +1,5 @@ +use crate::invalid_numa_strategy::InvalidNumaStrategy; + /// NUMA (Non-Uniform Memory Access) thread affinity strategy for llama.cpp. #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum NumaStrategy { @@ -13,13 +15,6 @@ pub enum NumaStrategy { Mirror, } -/// An invalid numa strategy was provided. -#[derive(Debug, Eq, PartialEq, Copy, Clone)] -pub struct InvalidNumaStrategy( - /// The invalid numa strategy that was provided. - pub llama_cpp_bindings_sys::ggml_numa_strategy, -); - impl TryFrom for NumaStrategy { type Error = InvalidNumaStrategy; @@ -49,7 +44,8 @@ impl From for llama_cpp_bindings_sys::ggml_numa_strategy { #[cfg(test)] mod tests { - use super::{InvalidNumaStrategy, NumaStrategy}; + use super::NumaStrategy; + use crate::invalid_numa_strategy::InvalidNumaStrategy; #[test] fn numa_from_and_to() { diff --git a/llama-cpp-bindings/src/llguidance_sampler.rs b/llama-cpp-bindings/src/llguidance_sampler.rs index 67da9f09..ffd51d75 100644 --- a/llama-cpp-bindings/src/llguidance_sampler.rs +++ b/llama-cpp-bindings/src/llguidance_sampler.rs @@ -34,10 +34,8 @@ unsafe extern "C" fn llg_accept( let ctx = unsafe { &mut *(*smpl).ctx.cast::() }; if let Err(consume_error) = ctx.matcher.consume_token(token.cast_unsigned()) { - tracing::warn!( - token = token, - error = %consume_error, - "llguidance sampler failed to consume token" + log::warn!( + "llguidance sampler failed to consume token: token={token}, error={consume_error}", ); } } @@ -52,9 +50,8 @@ unsafe extern "C" fn llg_apply( let mask = match ctx.matcher.compute_mask() { Ok(mask) => mask, Err(compute_error) => { - tracing::warn!( - error = %compute_error, - "llguidance sampler failed to compute mask, skipping constraint application" + log::warn!( + "llguidance sampler failed to compute mask, skipping constraint application: error={compute_error}", ); return; @@ -73,10 +70,7 @@ unsafe extern "C" fn llg_reset(smpl: *mut llama_cpp_bindings_sys::llama_sampler) let ctx = unsafe { &mut *(*smpl).ctx.cast::() }; if let Err(reset_error) = ctx.matcher.reset() { - tracing::warn!( - error = %reset_error, - "llguidance sampler failed to reset" - ); + log::warn!("llguidance sampler failed to reset: error={reset_error}"); } } diff --git a/llama-cpp-bindings/src/log.rs b/llama-cpp-bindings/src/log.rs deleted file mode 100644 index 639cc5f0..00000000 --- a/llama-cpp-bindings/src/log.rs +++ /dev/null @@ -1,1022 +0,0 @@ -use crate::log_options::LogOptions; -use std::sync::OnceLock; -use tracing_core::{Interest, Kind, Metadata, callsite, field, identify_callsite}; - -static FIELD_NAMES: &[&str] = &["message", "module"]; - -struct OverridableFields { - message: tracing::field::Field, - target: tracing::field::Field, -} - -macro_rules! log_cs { - ($level:expr, $cs:ident, $meta:ident, $fields:ident, $ty:ident) => { - struct $ty; - static $cs: $ty = $ty; - static $meta: Metadata<'static> = Metadata::new( - "log event", - "llama-cpp-bindings", - $level, - ::core::option::Option::None, - ::core::option::Option::None, - ::core::option::Option::None, - field::FieldSet::new(FIELD_NAMES, identify_callsite!(&$cs)), - Kind::EVENT, - ); - static $fields: std::sync::LazyLock = std::sync::LazyLock::new(|| { - let fields = $meta.fields(); - OverridableFields { - message: fields - .field("message") - .expect("message field defined in FIELD_NAMES"), - target: fields - .field("module") - .expect("module field defined in FIELD_NAMES"), - } - }); - - impl callsite::Callsite for $ty { - fn set_interest(&self, _: Interest) {} - fn metadata(&self) -> &'static Metadata<'static> { - &$meta - } - } - }; -} -log_cs!( - tracing_core::Level::DEBUG, - DEBUG_CS, - DEBUG_META, - DEBUG_FIELDS, - DebugCallsite -); -log_cs!( - tracing_core::Level::INFO, - INFO_CS, - INFO_META, - INFO_FIELDS, - InfoCallsite -); -log_cs!( - tracing_core::Level::WARN, - WARN_CS, - WARN_META, - WARN_FIELDS, - WarnCallsite -); -log_cs!( - tracing_core::Level::ERROR, - ERROR_CS, - ERROR_META, - ERROR_FIELDS, - ErrorCallsite -); - -#[derive(Clone, Copy)] -pub enum Module { - Ggml, - LlamaCpp, -} - -impl Module { - const fn name(self) -> &'static str { - match self { - Self::Ggml => "ggml", - Self::LlamaCpp => "llama.cpp", - } - } -} - -fn meta_for_level( - level: llama_cpp_bindings_sys::ggml_log_level, -) -> Option<(&'static Metadata<'static>, &'static OverridableFields)> { - match level { - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG => Some((&DEBUG_META, &DEBUG_FIELDS)), - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO => Some((&INFO_META, &INFO_FIELDS)), - llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN => Some((&WARN_META, &WARN_FIELDS)), - llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => Some((&ERROR_META, &ERROR_FIELDS)), - _ => None, - } -} - -pub struct State { - pub options: LogOptions, - module: Module, - buffered: std::sync::Mutex>, - previous_level: std::sync::atomic::AtomicI32, - is_buffering: std::sync::atomic::AtomicBool, -} - -impl State { - #[must_use] - pub fn new(module: Module, options: LogOptions) -> Self { - Self { - options, - module, - buffered: std::sync::Mutex::default(), - previous_level: std::sync::atomic::AtomicI32::default(), - is_buffering: std::sync::atomic::AtomicBool::default(), - } - } - - /// The match arms are duplicated per module because the `tracing` macros - /// require the `target` argument to be a string literal — the upstream - /// submodule name cannot be propagated dynamically. - fn generate_log(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) { - let (module, text) = text - .char_indices() - .take_while(|(_, ch)| ch.is_ascii_lowercase() || *ch == '_') - .last() - .and_then(|(pos, _)| { - let next_two = text.get(pos + 1..pos + 3); - if next_two == Some(": ") { - let (sub_module, text) = text.split_at(pos + 1); - let text = text.split_at(2).1; - Some((Some(format!("{}::{sub_module}", self.module.name())), text)) - } else { - None - } - }) - .unwrap_or((None, text)); - - let effective_level = if self.options.demote_info_to_debug - && (level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO - || level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG) - { - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG - } else { - level - }; - - let Some((meta, fields)) = meta_for_level(effective_level) else { - tracing::warn!( - level = effective_level, - text = text, - origin = "crate", - "generate_log called with unmapped log level" - ); - - return; - }; - - tracing::dispatcher::get_default(|dispatcher| { - dispatcher.event(&tracing::Event::new( - meta, - &meta.fields().value_set(&[ - (&fields.message, Some(&text as &dyn tracing::field::Value)), - ( - &fields.target, - module - .as_ref() - .map(|module_name| module_name as &dyn tracing::field::Value), - ), - ]), - )); - }); - } - - /// Append more text to the previously buffered log. - /// - /// The text may or may not end with a newline. - /// - /// # Panics - /// Panics if the internal mutex is poisoned. - pub fn cont_buffered_log(&self, text: &str) { - let mut lock = self.buffered.lock().unwrap(); - - if let Some((previous_log_level, mut buffer)) = lock.take() { - buffer.push_str(text); - if buffer.ends_with('\n') { - self.is_buffering - .store(false, std::sync::atomic::Ordering::Release); - self.generate_log(previous_log_level, buffer.as_str()); - } else { - *lock = Some((previous_log_level, buffer)); - } - } else { - let level = self - .previous_level - .load(std::sync::atomic::Ordering::Acquire) - .cast_unsigned(); - tracing::warn!( - inferred_level = level, - text = text, - origin = "crate", - "llama.cpp sent out a CONT log without any previously buffered message" - ); - *lock = Some((level, text.to_string())); - } - } - - /// Start buffering a message. Not the CONT log level and text is missing a newline. - /// - /// # Panics - /// Panics if the internal mutex is poisoned. - pub fn buffer_non_cont(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) { - let replaced = self - .buffered - .lock() - .unwrap() - .replace((level, text.to_string())); - - if let Some((previous_log_level, buffer)) = replaced { - tracing::warn!( - level = previous_log_level, - text = &buffer, - origin = "crate", - "Message buffered unnecessarily due to missing newline and not followed by a CONT" - ); - self.generate_log(previous_log_level, buffer.as_str()); - } - - self.is_buffering - .store(true, std::sync::atomic::Ordering::Release); - self.previous_level - .store(level.cast_signed(), std::sync::atomic::Ordering::Release); - } - - /// Emit a normal unbuffered log message (not the CONT log level and the text ends with a newline). - /// - /// # Panics - /// Panics if the internal mutex is poisoned. - pub fn emit_non_cont_line(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) { - if self - .is_buffering - .swap(false, std::sync::atomic::Ordering::Acquire) - && let Some((buf_level, buf_text)) = self.buffered.lock().unwrap().take() - { - tracing::warn!( - level = buf_level, - text = buf_text, - origin = "crate", - "llama.cpp message buffered spuriously due to missing \\n and being followed by a non-CONT message! (this indicates a bug within llama.cpp)" - ); - self.generate_log(buf_level, buf_text.as_str()); - } - - self.previous_level - .store(level.cast_signed(), std::sync::atomic::Ordering::Release); - - let (text, _trailing_newline) = text.split_at(text.len() - 1); - - match level { - llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE => { - if self.options.demote_info_to_debug { - self.generate_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG, text); - } else { - tracing::info!(no_log_level = true, text); - } - } - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG - | llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO - | llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN - | llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => self.generate_log(level, text), - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT => { - tracing::warn!( - text = text, - origin = "crate", - "CONT log level passed to emit_non_cont_line" - ); - } - _ => { - tracing::warn!( - level = level, - text = text, - origin = "crate", - "Unknown llama.cpp log level" - ); - } - } - } - - pub fn update_previous_level_for_disabled_log( - &self, - level: llama_cpp_bindings_sys::ggml_log_level, - ) { - if level != llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT { - self.previous_level - .store(level.cast_signed(), std::sync::atomic::Ordering::Release); - } - } - - /// Checks whether the given log level is enabled by the current tracing - /// subscriber. CONT lines inherit the previous line's level rather than - /// being checked on their own. - pub fn is_enabled_for_level(&self, level: llama_cpp_bindings_sys::ggml_log_level) -> bool { - let level = if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT { - self.previous_level - .load(std::sync::atomic::Ordering::Relaxed) - .cast_unsigned() - } else { - level - }; - - let effective_level = if self.options.demote_info_to_debug - && (level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO - || level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG) - { - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG - } else { - level - }; - - let Some((meta, _)) = meta_for_level(effective_level) else { - return false; - }; - - tracing::dispatcher::get_default(|dispatcher| dispatcher.enabled(meta)) - } -} - -pub static LLAMA_STATE: OnceLock> = OnceLock::new(); -pub static GGML_STATE: OnceLock> = OnceLock::new(); - -/// Bridges llama.cpp / ggml log callbacks into the `tracing` ecosystem. -/// -/// The fast path — newline-terminated DEBUG/INFO/WARN/ERROR lines — must avoid -/// taking the log state lock and must not allocate, so the buffering and -/// CONT-handling logic only runs on the slow path. Lines that lack a trailing -/// newline are buffered: their absence is the only signal upstream uses to -/// announce that a CONT message will follow, and we cannot distinguish that -/// from a typo until the next message arrives. -extern "C" fn logs_to_trace( - level: llama_cpp_bindings_sys::ggml_log_level, - text: *const ::std::os::raw::c_char, - data: *mut ::std::os::raw::c_void, -) { - use std::borrow::Borrow; - - let log_state = unsafe { &*(data as *const State) }; - - if log_state.options.disabled { - return; - } - - if !log_state.is_enabled_for_level(level) { - log_state.update_previous_level_for_disabled_log(level); - - return; - } - - let text = unsafe { std::ffi::CStr::from_ptr(text) }; - let text = text.to_string_lossy(); - let text: &str = text.borrow(); - - if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT { - log_state.cont_buffered_log(text); - } else if text.ends_with('\n') { - log_state.emit_non_cont_line(level, text); - } else { - log_state.buffer_non_cont(level, text); - } -} - -/// Redirect llama.cpp logs into tracing. -/// -/// `llama.cpp` and `ggml` are wired up to separate `State` instances so a CONT -/// line emitted by one cannot be appended to a buffered line from the other. -/// `llama_log_set` also installs the callback for `ggml`, so the `ggml_log_set` -/// call must come second to override that and bind the ggml state explicitly. -pub fn send_logs_to_tracing(options: LogOptions) { - let llama_heap_state = Box::as_ref( - LLAMA_STATE.get_or_init(|| Box::new(State::new(Module::LlamaCpp, options.clone()))), - ) as *const _; - let ggml_heap_state = - Box::as_ref(GGML_STATE.get_or_init(|| Box::new(State::new(Module::Ggml, options)))) - as *const _; - - unsafe { - llama_cpp_bindings_sys::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _); - llama_cpp_bindings_sys::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _); - } -} - -#[cfg(test)] -mod tests { - use std::sync::{Arc, Mutex}; - - use tracing_subscriber::util::SubscriberInitExt; - - use super::{Module, State, logs_to_trace}; - use crate::log_options::LogOptions; - - #[test] - fn module_name_ggml() { - assert_eq!(Module::Ggml.name(), "ggml"); - } - - #[test] - fn module_name_llama_cpp() { - assert_eq!(Module::LlamaCpp.name(), "llama.cpp"); - } - - #[test] - fn state_new_creates_empty_buffer() { - let state = State::new(Module::LlamaCpp, LogOptions::default()); - let buffer = state - .buffered - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - - assert!(buffer.is_none()); - drop(buffer); - assert!(!state.options.disabled); - } - - #[test] - fn update_previous_level_for_disabled_log_stores_level() { - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN); - - let stored = state - .previous_level - .load(std::sync::atomic::Ordering::Relaxed); - - assert_eq!( - stored, - llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN.cast_signed() - ); - } - - #[test] - fn update_previous_level_ignores_cont() { - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR); - state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT); - - let stored = state - .previous_level - .load(std::sync::atomic::Ordering::Relaxed); - - assert_eq!( - stored, - llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR.cast_signed() - ); - } - - #[test] - fn buffer_non_cont_sets_buffering_flag() { - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "partial"); - - assert!( - state - .is_buffering - .load(std::sync::atomic::Ordering::Relaxed) - ); - - let buffer = state - .buffered - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - - assert!(buffer.is_some()); - let (level, text) = buffer.as_ref().unwrap(); - assert_eq!(*level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO); - assert_eq!(text, "partial"); - drop(buffer); - } - - #[test] - fn cont_buffered_log_appends_to_existing_buffer() { - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "hello "); - - state.cont_buffered_log("world"); - - let buffer = state - .buffered - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - - assert!(buffer.is_some()); - let (_, text) = buffer.as_ref().unwrap(); - assert_eq!(text, "hello world"); - drop(buffer); - } - - struct Logger { - #[expect( - unused, - reason = "guard must outlive the test body so the tracing subscriber stays installed; \ - dropping it un-installs the subscriber and tests would silently miss log lines" - )] - guard: tracing::subscriber::DefaultGuard, - logs: Arc>>, - } - - #[derive(Clone)] - struct VecWriter(Arc>>); - - impl std::io::Write for VecWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let log_line = String::from_utf8_lossy(buf).into_owned(); - self.0.lock().unwrap().push(log_line); - - Ok(buf.len()) - } - - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } - } - - fn create_logger(max_level: tracing::Level) -> Logger { - let logs = Arc::new(Mutex::new(vec![])); - let writer = VecWriter(logs.clone()); - - Logger { - guard: tracing_subscriber::fmt() - .with_max_level(max_level) - .with_ansi(false) - .without_time() - .with_file(false) - .with_line_number(false) - .with_level(false) - .with_target(false) - .with_writer(move || writer.clone()) - .finish() - .set_default(), - logs, - } - } - - #[test] - fn cont_disabled_log() { - let logger = create_logger(tracing::Level::INFO); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG, - c"Hello ".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - c"world\n".as_ptr(), - log_ptr, - ); - - assert!(logger.logs.lock().unwrap().is_empty()); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG, - c"Hello ".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - c"world".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - c"\n".as_ptr(), - log_ptr, - ); - } - - #[test] - fn cont_message_concatenates_payload_then_flush_appends_extra_newline() { - let logger = create_logger(tracing::Level::INFO); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"Hello ".as_ptr(), - log_ptr, - ); - let cont_payload_with_newline = c"world\n"; - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - cont_payload_with_newline.as_ptr(), - log_ptr, - ); - - let payload_newline = '\n'; - let flush_appended_newline = '\n'; - assert_eq!( - *logger.logs.lock().unwrap(), - vec![format!( - "Hello world{payload_newline}{flush_appended_newline}" - )] - ); - } - - #[test] - fn disabled_logs_are_suppressed() { - let logger = create_logger(tracing::Level::DEBUG); - let disabled_options = LogOptions::default().with_logs_enabled(false); - let mut log_state = Box::new(State::new(Module::LlamaCpp, disabled_options)); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"Should not appear\n".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR, - c"Also suppressed\n".as_ptr(), - log_ptr, - ); - - assert!(logger.logs.lock().unwrap().is_empty()); - } - - #[test] - fn info_level_log_emitted() { - let logger = create_logger(tracing::Level::INFO); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"info message\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("info message")); - drop(logs); - } - - #[test] - fn warn_level_log_emitted() { - let logger = create_logger(tracing::Level::WARN); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN, - c"warning message\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("warning message")); - drop(logs); - } - - #[test] - fn error_level_log_emitted() { - let logger = create_logger(tracing::Level::ERROR); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR, - c"error message\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("error message")); - drop(logs); - } - - #[test] - fn debug_level_log_emitted_when_enabled() { - let logger = create_logger(tracing::Level::DEBUG); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG, - c"debug message\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("debug message")); - drop(logs); - } - - #[test] - fn submodule_extraction_from_log_text() { - let logger = create_logger(tracing::Level::INFO); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"sampling: initialized\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("initialized")); - drop(logs); - } - - #[test] - fn multi_part_cont_log() { - let logger = create_logger(tracing::Level::INFO); - let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default())); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"part1 ".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - c"part2 ".as_ptr(), - log_ptr, - ); - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - c"part3\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("part1 part2 part3")); - drop(logs); - } - - #[test] - fn demote_info_to_debug_suppresses_info_under_info_subscriber() { - let logger = create_logger(tracing::Level::INFO); - let options = LogOptions::default().with_demote_info_to_debug(true); - let mut log_state = Box::new(State::new(Module::LlamaCpp, options)); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"should be suppressed\n".as_ptr(), - log_ptr, - ); - - assert!(logger.logs.lock().unwrap().is_empty()); - } - - #[test] - fn demote_info_to_debug_emits_info_under_debug_subscriber() { - let logger = create_logger(tracing::Level::DEBUG); - let options = LogOptions::default().with_demote_info_to_debug(true); - let mut log_state = Box::new(State::new(Module::LlamaCpp, options)); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, - c"visible at debug\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("visible at debug")); - drop(logs); - } - - #[test] - fn demote_info_to_debug_preserves_error_under_info_subscriber() { - let logger = create_logger(tracing::Level::INFO); - let options = LogOptions::default().with_demote_info_to_debug(true); - let mut log_state = Box::new(State::new(Module::LlamaCpp, options)); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR, - c"error still visible\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("error still visible")); - drop(logs); - } - - #[test] - fn demote_info_to_debug_preserves_warn_under_info_subscriber() { - let logger = create_logger(tracing::Level::INFO); - let options = LogOptions::default().with_demote_info_to_debug(true); - let mut log_state = Box::new(State::new(Module::LlamaCpp, options)); - let log_ptr = - std::ptr::from_mut::(log_state.as_mut()).cast::(); - - logs_to_trace( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN, - c"warning still visible\n".as_ptr(), - log_ptr, - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("warning still visible")); - drop(logs); - } - - #[test] - fn emit_non_cont_line_level_none() { - let logger = create_logger(tracing::Level::INFO); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.emit_non_cont_line( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE, - "none level message\n", - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("none level message")); - drop(logs); - } - - #[test] - fn emit_non_cont_line_level_none_demoted_to_debug() { - let logger = create_logger(tracing::Level::DEBUG); - let options = LogOptions::default().with_demote_info_to_debug(true); - let state = State::new(Module::LlamaCpp, options); - - state.emit_non_cont_line( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE, - "demoted none\n", - ); - - let logs = logger.logs.lock().unwrap(); - assert_eq!(logs.len(), 1); - assert!(logs[0].contains("demoted none")); - drop(logs); - } - - #[test] - fn cont_without_prior_buffer_infers_level() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN); - state.cont_buffered_log("orphan text"); - - let buffer = state.buffered.lock().unwrap(); - assert!(buffer.is_some()); - let (level, text) = buffer.as_ref().unwrap(); - assert_eq!(*level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN); - assert_eq!(text, "orphan text"); - drop(buffer); - } - - #[test] - fn emit_non_cont_flushes_stale_buffer() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "stale"); - - state.emit_non_cont_line(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN, "new line\n"); - - let buffer = state.buffered.lock().unwrap(); - assert!(buffer.is_none()); - drop(buffer); - } - - #[test] - fn buffer_non_cont_replaces_previous_buffer() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "first"); - state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN, "second"); - - let buffer = state.buffered.lock().unwrap(); - let (level, text) = buffer.as_ref().unwrap(); - assert_eq!(*level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN); - assert_eq!(text, "second"); - drop(buffer); - } - - #[test] - fn is_enabled_for_cont_uses_previous_level() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR); - - let enabled = state.is_enabled_for_level(llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT); - - assert!(enabled); - } - - #[test] - fn unknown_log_level_emits_warning() { - let logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.emit_non_cont_line(9999, "unknown level message\n"); - - let logs = logger.logs.lock().unwrap(); - assert!( - logs.iter() - .any(|log_line| log_line.contains("Unknown llama.cpp log level")) - ); - drop(logs); - } - - #[test] - fn send_logs_to_tracing_initializes_global_states() { - use super::{GGML_STATE, LLAMA_STATE, send_logs_to_tracing}; - - send_logs_to_tracing(LogOptions::default()); - - assert!(LLAMA_STATE.get().is_some()); - assert!(GGML_STATE.get().is_some()); - } - - #[test] - fn meta_for_level_returns_none_for_unknown_level() { - let result = super::meta_for_level(9999); - - assert!(result.is_none()); - } - - #[test] - fn is_enabled_for_level_returns_false_for_none_level() { - let _logger = create_logger(tracing::Level::DEBUG); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - let enabled = state.is_enabled_for_level(llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE); - - assert!(!enabled); - } - - #[test] - fn generate_log_handles_unmapped_level_gracefully() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.generate_log(9999, "unmapped level message"); - } - - #[test] - fn emit_non_cont_line_handles_cont_level_gracefully() { - let _logger = create_logger(tracing::Level::WARN); - let state = State::new(Module::LlamaCpp, LogOptions::default()); - - state.emit_non_cont_line( - llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, - "cont passed to non-cont\n", - ); - } - - #[test] - fn callsite_metadata_returns_static_metadata() { - use tracing_core::callsite::Callsite; - - let debug_meta = super::DEBUG_CS.metadata(); - let info_meta = super::INFO_CS.metadata(); - let warn_meta = super::WARN_CS.metadata(); - let error_meta = super::ERROR_CS.metadata(); - - assert_eq!(*debug_meta.level(), tracing_core::Level::DEBUG); - assert_eq!(*info_meta.level(), tracing_core::Level::INFO); - assert_eq!(*warn_meta.level(), tracing_core::Level::WARN); - assert_eq!(*error_meta.level(), tracing_core::Level::ERROR); - } - - #[test] - fn callsite_set_interest_does_not_panic() { - use tracing_core::callsite::Callsite; - use tracing_core::subscriber::Interest; - - super::DEBUG_CS.set_interest(Interest::always()); - super::INFO_CS.set_interest(Interest::never()); - super::WARN_CS.set_interest(Interest::sometimes()); - super::ERROR_CS.set_interest(Interest::always()); - } - - #[test] - fn vec_writer_flush_succeeds() { - use std::io::Write; - - let mut writer = VecWriter(Arc::new(Mutex::new(vec![]))); - - writer.flush().unwrap(); - } -} diff --git a/llama-cpp-bindings/src/log_options.rs b/llama-cpp-bindings/src/log_options.rs index 65b583ba..ca6eacca 100644 --- a/llama-cpp-bindings/src/log_options.rs +++ b/llama-cpp-bindings/src/log_options.rs @@ -6,8 +6,8 @@ pub struct LogOptions { } impl LogOptions { - /// If enabled, logs are sent to tracing. If disabled, all logs are suppressed. Default is for - /// logs to be sent to tracing. + /// If enabled, logs are dispatched through the `log` crate. If disabled, all logs are + /// suppressed. Default is for logs to be dispatched. #[must_use] pub const fn with_logs_enabled(mut self, enabled: bool) -> Self { self.disabled = !enabled; @@ -15,9 +15,9 @@ impl LogOptions { self } - /// When enabled, llama.cpp and ggml INFO logs are demoted to DEBUG tracing level. WARN and + /// When enabled, llama.cpp and ggml INFO logs are dispatched at DEBUG level. WARN and /// ERROR logs retain their original severity. This suppresses verbose informational output - /// under a typical INFO-level subscriber while keeping important diagnostics visible. + /// under a typical INFO-level logger while keeping important diagnostics visible. /// All demoted logs remain available via `RUST_LOG=debug`. #[must_use] pub const fn with_demote_info_to_debug(mut self, demote: bool) -> Self { diff --git a/llama-cpp-bindings/src/model.rs b/llama-cpp-bindings/src/model.rs index e8d5ac01..de22549d 100644 --- a/llama-cpp-bindings/src/model.rs +++ b/llama-cpp-bindings/src/model.rs @@ -1,8 +1,22 @@ //! A safe wrapper around `llama_model`. + +pub mod add_bos; +pub mod llama_chat_message; +pub mod llama_chat_template; +pub mod llama_lora_adapter; +pub mod llama_split_mode_parse_error; +pub mod params; +pub mod rope_type; +pub mod split_mode; +pub mod vocab_type; +pub mod vocab_type_from_int_error; + use std::ffi::{CStr, CString, c_char}; use std::num::NonZeroU16; use std::os::raw::c_int; use std::path::Path; +use std::ptr; +use std::ptr::NonNull; use std::sync::Arc; use std::sync::OnceLock; @@ -10,25 +24,11 @@ use toktrie::ApproximateTokEnv; use toktrie::TokRxInfo; use toktrie::TokTrie; -fn truncated_buffer_to_string( - mut buffer: Vec, - length: usize, -) -> Result { - buffer.truncate(length); - - Ok(String::from_utf8(buffer)?) -} - -fn validate_string_length_for_tokenizer(length: usize) -> Result { - Ok(c_int::try_from(length)?) -} - -fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> { - let c_string = CString::new(str)?; - let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?; - Ok((c_string, len)) -} -use std::ptr::{self, NonNull}; +use llama_cpp_bindings_types::ParsedChatMessage; +use llama_cpp_bindings_types::ParsedToolCall; +use llama_cpp_bindings_types::ReasoningMarkers; +use llama_cpp_bindings_types::ToolCallArguments; +use llama_cpp_bindings_types::ToolCallMarkers; use crate::chat_message_parse_outcome::ChatMessageParseOutcome; use crate::ffi_status_to_i32::status_to_i32; @@ -39,32 +39,16 @@ use crate::raw_chat_message::RawChatMessage; use crate::resolved_tool_call_markers::ResolvedToolCallMarkers; use crate::sampled_token::SampledToken; use crate::sampled_token_classifier::SampledTokenClassifier; -use crate::sampled_token_classifier::StreamingMarkers; +use crate::streaming_markers::StreamingMarkers; use crate::token::LlamaToken; +use crate::tool_call_format; +use crate::tool_call_format::ToolCallFormatOutcome; +use crate::tool_call_template_overrides; use crate::{ ApplyChatTemplateError, ChatTemplateError, LlamaLoraAdapterInitError, LlamaModelLoadError, MarkerDetectionError, MetaValError, ParseChatMessageError, StringToTokenError, TokenToStringError, }; -use llama_cpp_bindings_types::ParsedChatMessage; -use llama_cpp_bindings_types::ParsedToolCall; -use llama_cpp_bindings_types::ReasoningMarkers; -use llama_cpp_bindings_types::ToolCallArguments; -use llama_cpp_bindings_types::ToolCallMarkers; - -use crate::tool_call_format; -use crate::tool_call_format::ToolCallFormatOutcome; -use crate::tool_call_template_overrides; - -pub mod add_bos; -pub mod llama_chat_message; -pub mod llama_chat_template; -pub mod llama_lora_adapter; -pub mod params; -pub mod rope_type; -pub mod split_mode; -pub mod vocab_type; -pub mod vocab_type_from_int_error; pub use add_bos::AddBos; pub use llama_chat_message::LlamaChatMessage; @@ -76,6 +60,25 @@ pub use vocab_type_from_int_error::VocabTypeFromIntError; use params::LlamaModelParams; +fn truncated_buffer_to_string( + mut buffer: Vec, + length: usize, +) -> Result { + buffer.truncate(length); + + Ok(String::from_utf8(buffer)?) +} + +fn validate_string_length_for_tokenizer(length: usize) -> Result { + Ok(c_int::try_from(length)?) +} + +fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> { + let c_string = CString::new(str)?; + let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?; + Ok((c_string, len)) +} + /// A safe wrapper around `llama_model`. pub struct LlamaModel { /// Raw pointer to the underlying `llama_model`. @@ -559,7 +562,6 @@ impl LlamaModel { /// # Panics /// /// Panics if a valid UTF-8 path somehow contains interior null bytes (should never happen). - #[tracing::instrument(skip_all, fields(params))] pub fn load_from_file( _: &LlamaBackend, path: impl AsRef, @@ -644,7 +646,6 @@ impl LlamaModel { /// /// # Errors /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information. - #[tracing::instrument(skip_all)] pub fn apply_chat_template( &self, tmpl: &LlamaChatTemplate, @@ -720,8 +721,8 @@ impl LlamaModel { let markers = match self.streaming_markers() { Ok(markers) => markers, Err(detection_error) => { - tracing::warn!( - "streaming markers detection failed; classifier will run blind: {detection_error}" + log::warn!( + "streaming markers detection failed; classifier will run blind: {detection_error}", ); StreamingMarkers::default() } @@ -843,8 +844,8 @@ impl LlamaModel { let template = match self.chat_template(None) { Ok(template) => template, Err(error) => { - tracing::debug!( - "tool-call markers unavailable: chat template missing or invalid: {error}" + log::debug!( + "tool-call markers unavailable: chat template missing or invalid: {error}", ); return None; } @@ -852,8 +853,8 @@ impl LlamaModel { let template_str = match template.to_str() { Ok(template_str) => template_str, Err(error) => { - tracing::debug!( - "tool-call markers unavailable: chat template is not valid UTF-8: {error}" + log::debug!( + "tool-call markers unavailable: chat template is not valid UTF-8: {error}", ); return None; } @@ -870,8 +871,8 @@ impl LlamaModel { Ok(tokens) if !tokens.is_empty() => Some(tokens), Ok(_) => None, Err(tokenize_error) => { - tracing::debug!( - "marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}" + log::debug!( + "marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}", ); None } diff --git a/llama-cpp-bindings/src/model/llama_split_mode_parse_error.rs b/llama-cpp-bindings/src/model/llama_split_mode_parse_error.rs new file mode 100644 index 00000000..ed644534 --- /dev/null +++ b/llama-cpp-bindings/src/model/llama_split_mode_parse_error.rs @@ -0,0 +1,8 @@ +/// An error that occurs when unknown split mode is encountered. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LlamaSplitModeParseError { + /// The value that could not be parsed as a split mode. + pub value: i32, + /// Additional context about why the parse failed. + pub context: String, +} diff --git a/llama-cpp-bindings/src/model/params.rs b/llama-cpp-bindings/src/model/params.rs index 519d2a22..ebd7edd7 100644 --- a/llama-cpp-bindings/src/model/params.rs +++ b/llama-cpp-bindings/src/model/params.rs @@ -3,17 +3,20 @@ use crate::LlamaCppError; use crate::context::params::LlamaContextParams; use crate::error::{FitError, ModelParamsError}; +use crate::model::llama_split_mode_parse_error::LlamaSplitModeParseError; use crate::model::params::fit_result::FitResult; use crate::model::params::kv_overrides::KvOverrides; -use crate::model::split_mode::{LlamaSplitMode, LlamaSplitModeParseError}; +use crate::model::split_mode::LlamaSplitMode; use std::ffi::{CStr, c_char}; use std::fmt::{Debug, Formatter}; use std::pin::Pin; use std::ptr::null; pub mod fit_result; +pub mod kv_override_value_iterator; pub mod kv_overrides; pub mod param_override_value; +pub mod unknown_kv_override_tag; /// The maximum number of devices supported. /// diff --git a/llama-cpp-bindings/src/model/params/kv_override_value_iterator.rs b/llama-cpp-bindings/src/model/params/kv_override_value_iterator.rs new file mode 100644 index 00000000..6073673d --- /dev/null +++ b/llama-cpp-bindings/src/model/params/kv_override_value_iterator.rs @@ -0,0 +1,53 @@ +use std::ffi::{CStr, CString}; +use std::fmt::Debug; + +use crate::model::params::LlamaModelParams; +use crate::model::params::param_override_value::ParamOverrideValue; + +/// An iterator over the key-value overrides for a model. +#[derive(Debug)] +pub struct KvOverrideValueIterator<'model_params> { + model_params: &'model_params LlamaModelParams, + current: usize, +} + +impl<'model_params> KvOverrideValueIterator<'model_params> { + #[must_use] + pub const fn new(model_params: &'model_params LlamaModelParams) -> Self { + Self { + model_params, + current: 0, + } + } +} + +impl Iterator for KvOverrideValueIterator<'_> { + type Item = (CString, ParamOverrideValue); + + fn next(&mut self) -> Option { + let overrides = self.model_params.params.kv_overrides; + + if overrides.is_null() { + return None; + } + + loop { + // SAFETY: llama.cpp guarantees the last element contains an empty key. + // We've checked the previous one in the last iteration, the next one + // should be valid or 0 (and thus safe to deref). + let current = unsafe { *overrides.add(self.current) }; + + if current.key[0] == 0 { + return None; + } + + self.current += 1; + + if let Ok(value) = ParamOverrideValue::try_from(¤t) { + let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() }; + + return Some((key, value)); + } + } + } +} diff --git a/llama-cpp-bindings/src/model/params/kv_overrides.rs b/llama-cpp-bindings/src/model/params/kv_overrides.rs index 688d9f62..d3f46c28 100644 --- a/llama-cpp-bindings/src/model/params/kv_overrides.rs +++ b/llama-cpp-bindings/src/model/params/kv_overrides.rs @@ -1,10 +1,10 @@ //! Key-value overrides for a model. -use crate::model::params::LlamaModelParams; -use crate::model::params::param_override_value::ParamOverrideValue; -use std::ffi::{CStr, CString}; use std::fmt::Debug; +use crate::model::params::LlamaModelParams; +use crate::model::params::kv_override_value_iterator::KvOverrideValueIterator; + /// A struct implementing [`IntoIterator`] over the key-value overrides for a model. #[derive(Debug)] pub struct KvOverrides<'model_params> { @@ -20,52 +20,11 @@ impl KvOverrides<'_> { } impl<'model_params> IntoIterator for KvOverrides<'model_params> { - type Item = (CString, ParamOverrideValue); + type Item = as Iterator>::Item; type IntoIter = KvOverrideValueIterator<'model_params>; fn into_iter(self) -> Self::IntoIter { - KvOverrideValueIterator { - model_params: self.model_params, - current: 0, - } - } -} - -/// An iterator over the key-value overrides for a model. -#[derive(Debug)] -pub struct KvOverrideValueIterator<'model_params> { - model_params: &'model_params LlamaModelParams, - current: usize, -} - -impl Iterator for KvOverrideValueIterator<'_> { - type Item = (CString, ParamOverrideValue); - - fn next(&mut self) -> Option { - let overrides = self.model_params.params.kv_overrides; - - if overrides.is_null() { - return None; - } - - loop { - // SAFETY: llama.cpp guarantees the last element contains an empty key. - // We've checked the previous one in the last iteration, the next one - // should be valid or 0 (and thus safe to deref). - let current = unsafe { *overrides.add(self.current) }; - - if current.key[0] == 0 { - return None; - } - - self.current += 1; - - if let Ok(value) = ParamOverrideValue::try_from(¤t) { - let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() }; - - return Some((key, value)); - } - } + KvOverrideValueIterator::new(self.model_params) } } diff --git a/llama-cpp-bindings/src/model/params/param_override_value.rs b/llama-cpp-bindings/src/model/params/param_override_value.rs index c4639fc3..b20e12af 100644 --- a/llama-cpp-bindings/src/model/params/param_override_value.rs +++ b/llama-cpp-bindings/src/model/params/param_override_value.rs @@ -1,3 +1,5 @@ +use crate::model::params::unknown_kv_override_tag::UnknownKvOverrideTag; + /// An override value for a model parameter. #[derive(Debug, Clone, Copy, PartialEq)] pub enum ParamOverrideValue { @@ -43,11 +45,6 @@ impl ParamOverrideValue { } } -/// Unknown KV override tag from the FFI layer. -#[derive(Debug, thiserror::Error)] -#[error("unknown KV override tag: {0}")] -pub struct UnknownKvOverrideTag(pub llama_cpp_bindings_sys::llama_model_kv_override_type); - impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue { type Error = UnknownKvOverrideTag; diff --git a/llama-cpp-bindings/src/model/params/unknown_kv_override_tag.rs b/llama-cpp-bindings/src/model/params/unknown_kv_override_tag.rs new file mode 100644 index 00000000..67978bde --- /dev/null +++ b/llama-cpp-bindings/src/model/params/unknown_kv_override_tag.rs @@ -0,0 +1,4 @@ +/// Unknown KV override tag from the FFI layer. +#[derive(Debug, thiserror::Error)] +#[error("unknown KV override tag: {0}")] +pub struct UnknownKvOverrideTag(pub llama_cpp_bindings_sys::llama_model_kv_override_type); diff --git a/llama-cpp-bindings/src/model/split_mode.rs b/llama-cpp-bindings/src/model/split_mode.rs index 595490a7..170c5596 100644 --- a/llama-cpp-bindings/src/model/split_mode.rs +++ b/llama-cpp-bindings/src/model/split_mode.rs @@ -1,3 +1,5 @@ +use crate::model::llama_split_mode_parse_error::LlamaSplitModeParseError; + /// A rusty wrapper around `llama_split_mode`. #[repr(i8)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -33,15 +35,6 @@ const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as )] const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_TENSOR as i8; -/// An error that occurs when unknown split mode is encountered. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LlamaSplitModeParseError { - /// The value that could not be parsed as a split mode. - pub value: i32, - /// Additional context about why the parse failed. - pub context: String, -} - /// Create a `LlamaSplitMode` from a `i32`. /// /// # Errors diff --git a/llama-cpp-bindings/src/mtmd.rs b/llama-cpp-bindings/src/mtmd.rs index e5787a83..7d87980a 100644 --- a/llama-cpp-bindings/src/mtmd.rs +++ b/llama-cpp-bindings/src/mtmd.rs @@ -8,25 +8,36 @@ pub mod image_chunk_batch_size_mismatch; pub mod mtmd_bitmap; +pub mod mtmd_bitmap_error; pub mod mtmd_context; pub mod mtmd_context_params; pub mod mtmd_default_marker; -pub mod mtmd_error; +pub mod mtmd_encode_error; +pub mod mtmd_eval_error; +pub mod mtmd_init_error; pub mod mtmd_input_chunk; +pub mod mtmd_input_chunk_error; pub mod mtmd_input_chunk_type; +pub mod mtmd_input_chunk_type_error; pub mod mtmd_input_chunks; +pub mod mtmd_input_chunks_error; pub mod mtmd_input_text; +pub mod mtmd_tokenize_error; pub use image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch; pub use mtmd_bitmap::MtmdBitmap; +pub use mtmd_bitmap_error::MtmdBitmapError; pub use mtmd_context::MtmdContext; pub use mtmd_context_params::MtmdContextParams; pub use mtmd_default_marker::mtmd_default_marker; -pub use mtmd_error::{ - MtmdBitmapError, MtmdEncodeError, MtmdEvalError, MtmdInitError, MtmdInputChunkError, - MtmdInputChunksError, MtmdTokenizeError, -}; +pub use mtmd_encode_error::MtmdEncodeError; +pub use mtmd_eval_error::MtmdEvalError; +pub use mtmd_init_error::MtmdInitError; pub use mtmd_input_chunk::MtmdInputChunk; -pub use mtmd_input_chunk_type::{MtmdInputChunkType, MtmdInputChunkTypeError}; +pub use mtmd_input_chunk_error::MtmdInputChunkError; +pub use mtmd_input_chunk_type::MtmdInputChunkType; +pub use mtmd_input_chunk_type_error::MtmdInputChunkTypeError; pub use mtmd_input_chunks::MtmdInputChunks; +pub use mtmd_input_chunks_error::MtmdInputChunksError; pub use mtmd_input_text::MtmdInputText; +pub use mtmd_tokenize_error::MtmdTokenizeError; diff --git a/llama-cpp-bindings/src/mtmd/image_chunk_batch_size_mismatch.rs b/llama-cpp-bindings/src/mtmd/image_chunk_batch_size_mismatch.rs index 992b0eec..dfac7f12 100644 --- a/llama-cpp-bindings/src/mtmd/image_chunk_batch_size_mismatch.rs +++ b/llama-cpp-bindings/src/mtmd/image_chunk_batch_size_mismatch.rs @@ -1,4 +1,4 @@ -/// Carried by [`super::mtmd_error::MtmdEvalError::ImageChunkExceedsBatchSize`]. +/// Carried by [`super::mtmd_eval_error::MtmdEvalError::ImageChunkExceedsBatchSize`]. /// /// `n_batch` is the per-decode batch budget enforced by `cparams.n_batch` in /// llama.cpp; `image_tokens` is the number of tokens this image chunk would diff --git a/llama-cpp-bindings/src/mtmd/mtmd_bitmap.rs b/llama-cpp-bindings/src/mtmd/mtmd_bitmap.rs index 8076d6e6..08e2ce6c 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_bitmap.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_bitmap.rs @@ -2,8 +2,8 @@ use std::ffi::{CStr, CString, c_char}; use std::ptr::NonNull; use std::slice; +use super::mtmd_bitmap_error::MtmdBitmapError; use super::mtmd_context::MtmdContext; -use super::mtmd_error::MtmdBitmapError; fn cstr_ptr_to_optional_string(ptr: *const c_char) -> Option { if ptr.is_null() { diff --git a/llama-cpp-bindings/src/mtmd/mtmd_bitmap_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_bitmap_error.rs new file mode 100644 index 00000000..c0ad849c --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_bitmap_error.rs @@ -0,0 +1,16 @@ +/// Errors that can occur when working with MTMD bitmaps +#[derive(thiserror::Error, Debug)] +pub enum MtmdBitmapError { + /// Failed to create `CString` from input + #[error("Failed to create CString: {0}")] + CStringError(#[from] std::ffi::NulError), + /// Invalid data size for bitmap + #[error("Invalid data size for bitmap")] + InvalidDataSize, + /// Image dimensions too small for processing (minimum 2x2) + #[error("Image dimensions too small: {0}x{1} (minimum 2x2)")] + ImageDimensionsTooSmall(u32, u32), + /// Bitmap creation returned null + #[error("Bitmap creation returned null")] + NullResult, +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_context.rs b/llama-cpp-bindings/src/mtmd/mtmd_context.rs index f9f29abb..4445a6ad 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_context.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_context.rs @@ -5,10 +5,12 @@ use crate::model::LlamaModel; use super::mtmd_bitmap::MtmdBitmap; use super::mtmd_context_params::MtmdContextParams; -use super::mtmd_error::{MtmdEncodeError, MtmdInitError, MtmdTokenizeError}; +use super::mtmd_encode_error::MtmdEncodeError; +use super::mtmd_init_error::MtmdInitError; use super::mtmd_input_chunk::MtmdInputChunk; use super::mtmd_input_chunks::MtmdInputChunks; use super::mtmd_input_text::MtmdInputText; +use super::mtmd_tokenize_error::MtmdTokenizeError; const fn tokenize_result_to_error(result: i32) -> MtmdTokenizeError { match result { diff --git a/llama-cpp-bindings/src/mtmd/mtmd_encode_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_encode_error.rs new file mode 100644 index 00000000..fabd3311 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_encode_error.rs @@ -0,0 +1,7 @@ +/// Errors that can occur during encoding +#[derive(thiserror::Error, Debug)] +pub enum MtmdEncodeError { + /// Encode operation failed + #[error("Encode failed with code: {0}")] + EncodeFailure(i32), +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_error.rs deleted file mode 100644 index 687b7243..00000000 --- a/llama-cpp-bindings/src/mtmd/mtmd_error.rs +++ /dev/null @@ -1,97 +0,0 @@ -/// Errors that can occur when initializing MTMD context -#[derive(thiserror::Error, Debug)] -pub enum MtmdInitError { - /// Failed to create `CString` from input - #[error("Failed to create CString: {0}")] - CStringError(#[from] std::ffi::NulError), - /// MTMD context initialization returned null - #[error("MTMD context initialization returned null")] - NullResult, -} - -/// Errors that can occur when working with MTMD bitmaps -#[derive(thiserror::Error, Debug)] -pub enum MtmdBitmapError { - /// Failed to create `CString` from input - #[error("Failed to create CString: {0}")] - CStringError(#[from] std::ffi::NulError), - /// Invalid data size for bitmap - #[error("Invalid data size for bitmap")] - InvalidDataSize, - /// Image dimensions too small for processing (minimum 2x2) - #[error("Image dimensions too small: {0}x{1} (minimum 2x2)")] - ImageDimensionsTooSmall(u32, u32), - /// Bitmap creation returned null - #[error("Bitmap creation returned null")] - NullResult, -} - -/// Errors that can occur when working with MTMD input chunks collections -#[derive(thiserror::Error, Debug)] -pub enum MtmdInputChunksError { - /// Input chunks creation returned null - #[error("Input chunks creation returned null")] - NullResult, -} - -/// Errors that can occur when working with individual MTMD input chunks -#[derive(thiserror::Error, Debug)] -pub enum MtmdInputChunkError { - /// Input chunk operation returned null - #[error("Input chunk operation returned null")] - NullResult, -} - -/// Errors that can occur during tokenization -#[derive(thiserror::Error, Debug)] -pub enum MtmdTokenizeError { - /// Number of bitmaps does not match number of markers in text - #[error("Number of bitmaps does not match number of markers")] - BitmapCountMismatch, - /// Image preprocessing error occurred - #[error("Image preprocessing error")] - ImagePreprocessingError, - /// Failed to create input chunks collection - #[error("{0}")] - InputChunksError(#[from] MtmdInputChunksError), - /// Text contains characters that cannot be converted to C string - #[error("Failed to create CString from text: {0}")] - CStringError(#[from] std::ffi::NulError), - /// Unknown error occurred during tokenization - #[error("Unknown error: {0}")] - UnknownError(i32), -} - -/// Errors that can occur during encoding -#[derive(thiserror::Error, Debug)] -pub enum MtmdEncodeError { - /// Encode operation failed - #[error("Encode failed with code: {0}")] - EncodeFailure(i32), -} - -use crate::mtmd::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch; - -/// Errors that can occur during evaluation -#[derive(thiserror::Error, Debug)] -pub enum MtmdEvalError { - /// Requested batch size exceeds the context's maximum batch size - #[error("batch size {requested} exceeds context batch size {context_max}")] - BatchSizeExceedsContextLimit { - /// The batch size requested in `eval_chunks` - requested: i32, - /// The maximum batch size configured on the context - context_max: u32, - }, - /// An image chunk's token count exceeds the per-decode `n_batch` budget, - /// so handing it to `llama_decode` would trip the `GGML_ASSERT`. - #[error( - "image chunk has {} tokens but n_batch is {}", - .0.image_tokens, - .0.n_batch, - )] - ImageChunkExceedsBatchSize(ImageChunkBatchSizeMismatch), - /// Evaluation operation failed - #[error("Eval failed with code: {0}")] - EvalFailure(i32), -} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_eval_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_eval_error.rs new file mode 100644 index 00000000..c4efa643 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_eval_error.rs @@ -0,0 +1,25 @@ +use crate::mtmd::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch; + +/// Errors that can occur during evaluation +#[derive(thiserror::Error, Debug)] +pub enum MtmdEvalError { + /// Requested batch size exceeds the context's maximum batch size + #[error("batch size {requested} exceeds context batch size {context_max}")] + BatchSizeExceedsContextLimit { + /// The batch size requested in `eval_chunks` + requested: i32, + /// The maximum batch size configured on the context + context_max: u32, + }, + /// An image chunk's token count exceeds the per-decode `n_batch` budget, + /// so handing it to `llama_decode` would trip the `GGML_ASSERT`. + #[error( + "image chunk has {} tokens but n_batch is {}", + .0.image_tokens, + .0.n_batch, + )] + ImageChunkExceedsBatchSize(ImageChunkBatchSizeMismatch), + /// Evaluation operation failed + #[error("Eval failed with code: {0}")] + EvalFailure(i32), +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_init_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_init_error.rs new file mode 100644 index 00000000..755d6a55 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_init_error.rs @@ -0,0 +1,10 @@ +/// Errors that can occur when initializing MTMD context +#[derive(thiserror::Error, Debug)] +pub enum MtmdInitError { + /// Failed to create `CString` from input + #[error("Failed to create CString: {0}")] + CStringError(#[from] std::ffi::NulError), + /// MTMD context initialization returned null + #[error("MTMD context initialization returned null")] + NullResult, +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk.rs index 4bfa1110..50643547 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk.rs @@ -7,9 +7,10 @@ use crate::token::LlamaToken; use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch; use super::mtmd_context::MtmdContext; -use super::mtmd_error::MtmdEvalError; -use super::mtmd_error::MtmdInputChunkError; -use super::mtmd_input_chunk_type::{MtmdInputChunkType, MtmdInputChunkTypeError}; +use super::mtmd_eval_error::MtmdEvalError; +use super::mtmd_input_chunk_error::MtmdInputChunkError; +use super::mtmd_input_chunk_type::MtmdInputChunkType; +use super::mtmd_input_chunk_type_error::MtmdInputChunkTypeError; /// # Safety /// diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_error.rs new file mode 100644 index 00000000..e44e1c30 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_error.rs @@ -0,0 +1,7 @@ +/// Errors that can occur when working with individual MTMD input chunks +#[derive(thiserror::Error, Debug)] +pub enum MtmdInputChunkError { + /// Input chunk operation returned null + #[error("Input chunk operation returned null")] + NullResult, +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type.rs index a779363a..ef628b89 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type.rs @@ -1,7 +1,4 @@ -/// Error when converting from an unknown MTMD input chunk type value. -#[derive(Debug, PartialEq, Eq, thiserror::Error)] -#[error("Unknown MTMD input chunk type: {0}")] -pub struct MtmdInputChunkTypeError(pub llama_cpp_bindings_sys::mtmd_input_chunk_type); +use crate::mtmd::mtmd_input_chunk_type_error::MtmdInputChunkTypeError; /// Input chunk types for multimodal data /// @@ -48,7 +45,7 @@ impl TryFrom for MtmdInputChunkTy #[cfg(test)] mod tests { use super::MtmdInputChunkType; - use super::MtmdInputChunkTypeError; + use crate::mtmd::mtmd_input_chunk_type_error::MtmdInputChunkTypeError; #[test] fn text_variant_converts_from_raw() { diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type_error.rs new file mode 100644 index 00000000..ae3ca7e8 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunk_type_error.rs @@ -0,0 +1,4 @@ +/// Error when converting from an unknown MTMD input chunk type value. +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +#[error("Unknown MTMD input chunk type: {0}")] +pub struct MtmdInputChunkTypeError(pub llama_cpp_bindings_sys::mtmd_input_chunk_type); diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs index d9b3a9d8..a74eb296 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs @@ -3,9 +3,9 @@ use std::ptr::NonNull; use crate::context::LlamaContext; use super::mtmd_context::MtmdContext; -use super::mtmd_error::MtmdEvalError; -use super::mtmd_error::MtmdInputChunksError; +use super::mtmd_eval_error::MtmdEvalError; use super::mtmd_input_chunk::MtmdInputChunk; +use super::mtmd_input_chunks_error::MtmdInputChunksError; const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> { if result == 0 { diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks_error.rs new file mode 100644 index 00000000..10a251d1 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks_error.rs @@ -0,0 +1,7 @@ +/// Errors that can occur when working with MTMD input chunks collections +#[derive(thiserror::Error, Debug)] +pub enum MtmdInputChunksError { + /// Input chunks creation returned null + #[error("Input chunks creation returned null")] + NullResult, +} diff --git a/llama-cpp-bindings/src/mtmd/mtmd_tokenize_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_tokenize_error.rs new file mode 100644 index 00000000..8886bc19 --- /dev/null +++ b/llama-cpp-bindings/src/mtmd/mtmd_tokenize_error.rs @@ -0,0 +1,21 @@ +use crate::mtmd::mtmd_input_chunks_error::MtmdInputChunksError; + +/// Errors that can occur during tokenization +#[derive(thiserror::Error, Debug)] +pub enum MtmdTokenizeError { + /// Number of bitmaps does not match number of markers in text + #[error("Number of bitmaps does not match number of markers")] + BitmapCountMismatch, + /// Image preprocessing error occurred + #[error("Image preprocessing error")] + ImagePreprocessingError, + /// Failed to create input chunks collection + #[error("{0}")] + InputChunksError(#[from] MtmdInputChunksError), + /// Text contains characters that cannot be converted to C string + #[error("Failed to create CString from text: {0}")] + CStringError(#[from] std::ffi::NulError), + /// Unknown error occurred during tokenization + #[error("Unknown error: {0}")] + UnknownError(i32), +} diff --git a/llama-cpp-bindings/src/sampled_token_classifier.rs b/llama-cpp-bindings/src/sampled_token_classifier.rs index 89c034f2..83d0d108 100644 --- a/llama-cpp-bindings/src/sampled_token_classifier.rs +++ b/llama-cpp-bindings/src/sampled_token_classifier.rs @@ -17,81 +17,11 @@ use crate::mtmd::MtmdInputChunks; use crate::sampled_token::SampledToken; use crate::sampling::LlamaSampler; use crate::streaming_json_probe::JsonProbeOutcome; +use crate::streaming_markers::{MarkerKind, StreamingMarkers}; use crate::token::LlamaToken; -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub enum SampledTokenSection { - Pending, - Content, - Reasoning, - ToolCall, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -enum MarkerKind { - ReasoningOpen, - ReasoningClose, - ToolCallOpen, - ToolCallClose, -} - -/// Tokenized marker sequences (token IDs, not strings). -/// -/// Each marker is a `Vec` of length `>= 1`; absent markers are -/// `None`. Sequence matching at every `ingest()` is by token-ID equality, -/// never by substring scanning of decoded text. -#[derive(Clone, Debug, Default, Eq, PartialEq)] -pub struct StreamingMarkers { - pub reasoning_open: Option>, - pub reasoning_close: Option>, - pub tool_call_open: Option>, - pub tool_call_close: Option>, -} - -impl StreamingMarkers { - const fn has_any(&self) -> bool { - self.reasoning_open.is_some() - || self.reasoning_close.is_some() - || self.tool_call_open.is_some() - || self.tool_call_close.is_some() - } - - fn max_token_len(&self) -> usize { - [ - self.reasoning_open.as_deref(), - self.reasoning_close.as_deref(), - self.tool_call_open.as_deref(), - self.tool_call_close.as_deref(), - ] - .into_iter() - .flatten() - .map(<[LlamaToken]>::len) - .max() - .unwrap_or(0) - } - - fn lookup(&self, kind: MarkerKind) -> Option<&[LlamaToken]> { - match kind { - MarkerKind::ReasoningOpen => self.reasoning_open.as_deref(), - MarkerKind::ReasoningClose => self.reasoning_close.as_deref(), - MarkerKind::ToolCallOpen => self.tool_call_open.as_deref(), - MarkerKind::ToolCallClose => self.tool_call_close.as_deref(), - } - } -} - -#[derive(Clone, Debug)] -pub struct IngestOutcome { - pub sampled_token: SampledToken, - /// Empty when the token is part of a recognised marker boundary; otherwise - /// the decoded UTF-8 piece. Callers should stream `visible_piece` and skip - /// emission when it is empty. - pub visible_piece: String, - /// Always the decoded UTF-8 piece, even for marker-boundary tokens. Useful - /// for accumulating the full raw model output (e.g. for downstream parser - /// cross-checks) without losing marker bytes. - pub raw_piece: String, -} +pub use crate::ingest_outcome::IngestOutcome; +pub use crate::sampled_token_section::SampledTokenSection; #[derive(Clone, Debug)] struct PendingToken { @@ -250,8 +180,8 @@ impl<'model> SampledTokenClassifier<'model> { ) { Ok(piece) => piece, Err(detokenize_error) => { - tracing::debug!( - "token_to_piece failed during classification, dropping piece: {detokenize_error}" + log::debug!( + "token_to_piece failed during classification, dropping piece: {detokenize_error}", ); String::new() } @@ -259,14 +189,6 @@ impl<'model> SampledTokenClassifier<'model> { } fn try_consume_marker_at_tail(&mut self) { - // Probe every marker in every section so the user-visible streams stay - // free of marker text even when the model misbehaves: a stray - // `` / `` / `[/THINK]` while in `Content` is - // suppressed (close markers transition to Content — a no-op when - // already there); a nested `` while in `Reasoning` is also - // suppressed (open markers keep the section in Reasoning). Without - // this, models like Gemma 4 E4B that emit close markers without ever - // opening leak the literal marker text into `content_stream`. const PROBE_KINDS: &[MarkerKind] = &[ MarkerKind::ReasoningOpen, MarkerKind::ReasoningClose, @@ -301,15 +223,6 @@ impl<'model> SampledTokenClassifier<'model> { MarkerKind::ReasoningClose | MarkerKind::ToolCallClose => SampledTokenSection::Content, MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall, }; - // For open markers, the boundary tokens are classified as the destination - // section — they are the marker itself (`` is part of reasoning, - // `` is part of the tool-call protocol). For close markers, - // the boundary tokens are classified as the section the model was in: - // a normal `` while in `Reasoning` is still reasoning, but a - // spurious `` while in `Content` (e.g. some Gemma variants - // re-emit close markers without ever opening) is just noise in the - // content section — counting it as `Reasoning` would inflate - // `observed_reasoning` and falsely indicate the model thought. let span_section = match kind { MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning, MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall, @@ -584,9 +497,6 @@ impl<'model> SampledTokenClassifier<'model> { logits_last: bool, ) -> Result { let chunk_count = chunks.len(); - // `start_position` stays read-only; `next_position` is the loop - // accumulator that walks forward chunk-by-chunk and is the function's - // return value. Two locals, single responsibility each. let mut next_position = start_position; for index in 0..chunk_count { @@ -651,13 +561,13 @@ impl<'model> SampledTokenClassifier<'model> { #[cfg(test)] mod tests { - use super::IngestOutcome; use super::PendingToken; use super::ProbeMode; use super::SampledTokenClassifier; - use super::SampledTokenSection; - use super::StreamingMarkers; + use crate::ingest_outcome::IngestOutcome; use crate::sampled_token::SampledToken; + use crate::sampled_token_section::SampledTokenSection; + use crate::streaming_markers::StreamingMarkers; use crate::token::LlamaToken; fn token(id: i32) -> LlamaToken { @@ -676,9 +586,6 @@ mod tests { } } - /// Builds a classifier without a real model — only safe for tests that go - /// through `try_consume_marker_at_tail` / `drain_overflow` directly, never - /// through `ingest()` (which calls `model.token_to_piece`). fn synthetic_classifier(markers: StreamingMarkers) -> SampledTokenClassifier<'static> { SampledTokenClassifier { model: unsafe { &*std::ptr::NonNull::::dangling().as_ptr() }, @@ -750,24 +657,6 @@ mod tests { .collect() } - #[test] - fn streaming_markers_with_no_markers_reports_none() { - let markers = StreamingMarkers::default(); - assert!(!markers.has_any()); - assert_eq!(markers.max_token_len(), 0); - } - - #[test] - fn streaming_markers_max_token_len_takes_longest() { - let markers = StreamingMarkers { - reasoning_open: Some(vec![token(1)]), - reasoning_close: Some(vec![token(2), token(3), token(4)]), - tool_call_open: Some(vec![token(5), token(6)]), - tool_call_close: None, - }; - assert_eq!(markers.max_token_len(), 3); - } - #[test] fn single_token_close_marker_when_already_in_reasoning_emits_empty_piece_for_marker() { let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)])); @@ -1065,7 +954,6 @@ mod tests { let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)])); let mut classifier = synthetic_classifier(markers); - // body body for token_id in [100, 7, 200, 100, 8, 200] { push_pending_from_prompt(&mut classifier, token_id); classifier.try_consume_marker_at_tail(); @@ -1092,7 +980,6 @@ mod tests { let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)])); let mut classifier = synthetic_classifier(markers); - // Closed-think prompt: body for token_id in [100, 7, 200] { push_pending_from_prompt(&mut classifier, token_id); classifier.try_consume_marker_at_tail(); @@ -1103,9 +990,6 @@ mod tests { assert_eq!(classifier.usage().reasoning_tokens, 0); assert_eq!(classifier.usage().content_tokens, 0); - // Generated content token (not from prompt): pushed with section=Content, - // is_from_prompt=false. drain_overflow finalises it as SampledToken::Content - // and increments usage.content_tokens. classifier.pending.push_back(PendingToken { token: token(50), decoded: "hi".to_owned(), @@ -1130,13 +1014,6 @@ mod tests { #[test] fn close_marker_in_content_section_is_suppressed_as_boundary() { - // When a misbehaving model emits a close marker (e.g. ``) while - // already in the Content section, the classifier must treat it as a - // boundary so the marker text never reaches the user-visible content - // stream. The boundary token is classified as Content (not Reasoning): - // there is no reasoning to close, the close marker is just noise in - // the content section. This is the architectural backstop against - // models that re-emit close markers without a preceding open. let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)])); let mut classifier = synthetic_classifier(markers); classifier.section = SampledTokenSection::Content; @@ -1157,17 +1034,12 @@ mod tests { SampledTokenSection::Content, ], ); - // The close marker's `visible_piece` is empty (boundary), so the - // user-visible content stream is "hi" + "" + "ok" = "hiok". assert_eq!(outcome_pieces(&outcomes), vec!["hi", "", "ok"]); assert_eq!(classifier.section, SampledTokenSection::Content); } #[test] fn open_marker_in_reasoning_section_is_suppressed_as_boundary() { - // A nested `` while already in Reasoning is suppressed (so the - // user never sees the marker text in the reasoning stream) and the - // section stays Reasoning. let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)])); let mut classifier = synthetic_classifier(markers); classifier.section = SampledTokenSection::Reasoning; @@ -1240,8 +1112,6 @@ mod tests { #[test] fn spurious_tool_call_close_in_content_section_classifies_as_content() { - // A `` while in Content (model misbehaves) is classified as - // Content (not ToolCall) so observed_tool_calls isn't inflated. let mut markers = markers_with(None, None); markers.tool_call_close = Some(vec![token(300)]); let mut classifier = synthetic_classifier(markers); @@ -1504,11 +1374,6 @@ mod tests { #[test] fn marker_probe_takes_precedence_when_both_could_match() { - // Marker is a single token whose decoded text starts with `"` (a JSON - // signature-valid byte). The JSON probe holds the leading `{`, the - // marker matches at the next token, the section transitions to ToolCall, - // the JSON probe abandons. The leading `{` releases as Content; the - // marker token releases as a ToolCall boundary (suppressed). let markers = markers_with_tool_call_open(vec![token(900)]); let mut classifier = synthetic_classifier(markers); classifier.section = SampledTokenSection::Content; diff --git a/llama-cpp-bindings/src/sampled_token_section.rs b/llama-cpp-bindings/src/sampled_token_section.rs new file mode 100644 index 00000000..4b54fe88 --- /dev/null +++ b/llama-cpp-bindings/src/sampled_token_section.rs @@ -0,0 +1,7 @@ +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SampledTokenSection { + Pending, + Content, + Reasoning, + ToolCall, +} diff --git a/llama-cpp-bindings/src/send_logs_to_log.rs b/llama-cpp-bindings/src/send_logs_to_log.rs new file mode 100644 index 00000000..6bd8fbb7 --- /dev/null +++ b/llama-cpp-bindings/src/send_logs_to_log.rs @@ -0,0 +1,422 @@ +#![deny(clippy::expect_used)] +#![deny(clippy::indexing_slicing)] +#![deny(clippy::panic)] +#![deny(clippy::unwrap_used)] + +use std::sync::{Mutex, OnceLock}; + +use llama_cpp_log_decoder::decode_anomaly::DecodeAnomaly; +use llama_cpp_log_decoder::decode_output::DecodeOutput; +use llama_cpp_log_decoder::incoming_log_level::IncomingLogLevel; +use llama_cpp_log_decoder::log_decoder::LogDecoder; +use llama_cpp_log_decoder::log_level::LogLevel; +use llama_cpp_log_decoder::log_line::LogLine; + +use crate::log_options::LogOptions; + +struct LogSource { + decoder: Mutex, + target: &'static str, + options: LogOptions, +} + +impl LogSource { + const fn new(target: &'static str, options: LogOptions) -> Self { + Self { + decoder: Mutex::new(LogDecoder::new()), + target, + options, + } + } +} + +static LLAMA_SOURCE: OnceLock = OnceLock::new(); +static GGML_SOURCE: OnceLock = OnceLock::new(); + +const fn ggml_level_to_incoming(raw: llama_cpp_bindings_sys::ggml_log_level) -> IncomingLogLevel { + match raw { + llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE => IncomingLogLevel::None, + llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG => IncomingLogLevel::Debug, + llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO => IncomingLogLevel::Info, + llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN => IncomingLogLevel::Warn, + llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => IncomingLogLevel::Error, + llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT => IncomingLogLevel::Cont, + other => IncomingLogLevel::Unknown(other), + } +} + +fn resolve_record(line: LogLine, demote_info_to_debug: bool) -> (log::Level, String) { + let effective_level = + if demote_info_to_debug && matches!(line.level, LogLevel::Info | LogLevel::None) { + LogLevel::Debug + } else { + line.level + }; + + match effective_level { + LogLevel::Debug => (log::Level::Debug, line.text), + LogLevel::Info | LogLevel::None => (log::Level::Info, line.text), + LogLevel::Warn => (log::Level::Warn, line.text), + LogLevel::Error => (log::Level::Error, line.text), + LogLevel::Unknown(raw) => ( + log::Level::Warn, + format!("[unknown level {raw}] {}", line.text), + ), + } +} + +fn dispatch_line(source: &LogSource, line: LogLine) { + let (level, message) = resolve_record(line, source.options.demote_info_to_debug); + log::log!(target: source.target, level, "{message}"); +} + +fn dispatch_output(source: &LogSource, output: DecodeOutput) { + match output { + DecodeOutput::None => {} + DecodeOutput::Line(line) => dispatch_line(source, line), + DecodeOutput::TwoLines { earlier, current } => { + dispatch_line(source, earlier); + dispatch_line(source, current); + } + } +} + +fn dispatch_anomaly(source: &LogSource, anomaly: DecodeAnomaly) { + log::warn!( + target: source.target, + "llama.cpp log decoder anomaly: {anomaly:?}", + ); +} + +unsafe extern "C" fn logs_to_log( + raw_level: llama_cpp_bindings_sys::ggml_log_level, + text_ptr: *const std::os::raw::c_char, + data_ptr: *mut std::os::raw::c_void, +) { + let source: &LogSource = unsafe { &*data_ptr.cast::() }; + + if source.options.disabled { + return; + } + + if text_ptr.is_null() { + log::warn!( + target: source.target, + "received NULL text pointer from llama.cpp log callback", + ); + return; + } + + let text_cstr = unsafe { std::ffi::CStr::from_ptr(text_ptr) }; + let text = text_cstr.to_string_lossy(); + + let incoming = ggml_level_to_incoming(raw_level); + + let result = { + let mut decoder = source + .decoder + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + decoder.feed(incoming, &text) + }; + + dispatch_output(source, result.output); + + if let Some(anomaly) = result.anomaly { + dispatch_anomaly(source, anomaly); + } +} + +pub fn send_logs_to_log(options: LogOptions) { + let llama_source: *const LogSource = + LLAMA_SOURCE.get_or_init(|| LogSource::new("llama.cpp", options.clone())); + let ggml_source: *const LogSource = GGML_SOURCE.get_or_init(|| LogSource::new("ggml", options)); + + unsafe { + llama_cpp_bindings_sys::llama_log_set( + Some(logs_to_log), + llama_source.cast::().cast_mut(), + ); + llama_cpp_bindings_sys::ggml_log_set( + Some(logs_to_log), + ggml_source.cast::().cast_mut(), + ); + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Mutex, Once}; + + use llama_cpp_log_decoder::incoming_log_level::IncomingLogLevel; + use log::{Level, Log, Metadata, Record}; + use serial_test::serial; + + use super::{ + GGML_SOURCE, LLAMA_SOURCE, LogSource, ggml_level_to_incoming, logs_to_log, send_logs_to_log, + }; + use crate::log_options::LogOptions; + + #[derive(Clone, Debug)] + struct CapturedRecord { + level: Level, + target: String, + message: String, + } + + struct TestLogger { + records: Mutex>, + } + + impl Log for TestLogger { + fn enabled(&self, _: &Metadata) -> bool { + true + } + + fn log(&self, record: &Record) { + let mut guard = self + .records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + guard.push(CapturedRecord { + level: record.level(), + target: record.target().to_owned(), + message: record.args().to_string(), + }); + } + + fn flush(&self) {} + } + + static TEST_LOGGER: TestLogger = TestLogger { + records: Mutex::new(Vec::new()), + }; + static INSTALL: Once = Once::new(); + + fn ensure_test_logger_installed() { + INSTALL.call_once(|| { + if log::set_logger(&TEST_LOGGER).is_ok() { + log::set_max_level(log::LevelFilter::Trace); + } + }); + } + + fn records_for(target: &str) -> Vec { + let guard = TEST_LOGGER + .records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + guard + .iter() + .filter(|record| record.target == target) + .cloned() + .collect() + } + + fn invoke_callback( + level: llama_cpp_bindings_sys::ggml_log_level, + text: &std::ffi::CStr, + source: &LogSource, + ) { + let ptr = std::ptr::from_ref(source) + .cast::() + .cast_mut(); + unsafe { + logs_to_log(level, text.as_ptr(), ptr); + } + } + + #[test] + fn ggml_level_to_incoming_known_constants() { + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE), + IncomingLogLevel::None, + ); + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG), + IncomingLogLevel::Debug, + ); + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO), + IncomingLogLevel::Info, + ); + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN), + IncomingLogLevel::Warn, + ); + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR), + IncomingLogLevel::Error, + ); + assert_eq!( + ggml_level_to_incoming(llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT), + IncomingLogLevel::Cont, + ); + } + + #[test] + fn ggml_level_to_incoming_unknown_value() { + assert_eq!( + ggml_level_to_incoming(9999), + IncomingLogLevel::Unknown(9999) + ); + } + + #[test] + fn dispatch_when_disabled() { + ensure_test_logger_installed(); + + let target = "test-dispatch-when-disabled"; + let source = LogSource::new(target, LogOptions::default().with_logs_enabled(false)); + invoke_callback( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, + c"hello\n", + &source, + ); + + assert!(records_for(target).is_empty()); + } + + #[test] + fn demote_info_to_debug_on_info() { + ensure_test_logger_installed(); + + let target = "test-demote-info-on-info"; + let source = LogSource::new( + target, + LogOptions::default().with_demote_info_to_debug(true), + ); + invoke_callback( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, + c"info-line\n", + &source, + ); + + assert!(records_for(target).iter().any(|record| { + record.level == Level::Debug && record.message.contains("info-line") + })); + } + + #[test] + fn demote_info_to_debug_on_warn() { + ensure_test_logger_installed(); + + let target = "test-demote-info-on-warn"; + let source = LogSource::new( + target, + LogOptions::default().with_demote_info_to_debug(true), + ); + invoke_callback( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN, + c"warn-line\n", + &source, + ); + + assert!( + records_for(target).iter().any(|record| { + record.level == Level::Warn && record.message.contains("warn-line") + }) + ); + } + + #[test] + fn dispatch_unknown_level() { + ensure_test_logger_installed(); + + let target = "test-dispatch-unknown-level"; + let source = LogSource::new(target, LogOptions::default()); + invoke_callback(9999, c"weird\n", &source); + + assert!(records_for(target).iter().any(|record| { + record.level == Level::Warn + && record.message.contains("[unknown level 9999]") + && record.message.contains("weird") + })); + } + + #[test] + fn dispatch_orphan_cont_anomaly() { + ensure_test_logger_installed(); + + let target = "test-dispatch-orphan-cont"; + let source = LogSource::new(target, LogOptions::default()); + invoke_callback( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, + c"ghost\n", + &source, + ); + + assert!(records_for(target).iter().any(|record| { + record.level == Level::Warn && record.message.contains("OrphanCont") + })); + } + + #[test] + #[serial] + fn send_logs_to_log_initialization() { + ensure_test_logger_installed(); + send_logs_to_log(LogOptions::default()); + + assert!(LLAMA_SOURCE.get().is_some()); + assert!(GGML_SOURCE.get().is_some()); + } + + #[test] + fn null_text_pointer() { + ensure_test_logger_installed(); + + let target = "test-null-text-pointer"; + let source = LogSource::new(target, LogOptions::default()); + let source_ptr = std::ptr::from_ref(&source) + .cast::() + .cast_mut(); + unsafe { + logs_to_log( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, + std::ptr::null(), + source_ptr, + ); + } + + assert!(records_for(target).iter().any(|record| { + record.level == Level::Warn && record.message.contains("NULL text pointer") + })); + } + + #[test] + #[expect( + clippy::panic, + reason = "deliberate panic to poison the decoder mutex for fault-injection coverage" + )] + fn decoder_mutex_poison() { + ensure_test_logger_installed(); + + let target = "test-decoder-mutex-poison"; + let source = LogSource::new(target, LogOptions::default()); + + std::thread::scope(|scope| { + let handle = scope.spawn(|| { + let _guard = source + .decoder + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + panic!("intentional poison"); + }); + let _ = handle.join(); + }); + + assert!(source.decoder.is_poisoned()); + + invoke_callback( + llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, + c"after-poison\n", + &source, + ); + + assert!( + records_for(target) + .iter() + .any(|record| record.message.contains("after-poison")) + ); + } +} diff --git a/llama-cpp-bindings/src/streaming_markers.rs b/llama-cpp-bindings/src/streaming_markers.rs new file mode 100644 index 00000000..9eaaddf2 --- /dev/null +++ b/llama-cpp-bindings/src/streaming_markers.rs @@ -0,0 +1,85 @@ +use crate::token::LlamaToken; + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum MarkerKind { + ReasoningOpen, + ReasoningClose, + ToolCallOpen, + ToolCallClose, +} + +/// Tokenized marker sequences (token IDs, not strings). +/// +/// Each marker is a `Vec` of length `>= 1`; absent markers are +/// `None`. Sequence matching at every `ingest()` is by token-ID equality, +/// never by substring scanning of decoded text. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct StreamingMarkers { + pub reasoning_open: Option>, + pub reasoning_close: Option>, + pub tool_call_open: Option>, + pub tool_call_close: Option>, +} + +impl StreamingMarkers { + #[must_use] + pub const fn has_any(&self) -> bool { + self.reasoning_open.is_some() + || self.reasoning_close.is_some() + || self.tool_call_open.is_some() + || self.tool_call_close.is_some() + } + + #[must_use] + pub fn max_token_len(&self) -> usize { + [ + self.reasoning_open.as_deref(), + self.reasoning_close.as_deref(), + self.tool_call_open.as_deref(), + self.tool_call_close.as_deref(), + ] + .into_iter() + .flatten() + .map(<[LlamaToken]>::len) + .max() + .unwrap_or(0) + } + + #[must_use] + pub fn lookup(&self, kind: MarkerKind) -> Option<&[LlamaToken]> { + match kind { + MarkerKind::ReasoningOpen => self.reasoning_open.as_deref(), + MarkerKind::ReasoningClose => self.reasoning_close.as_deref(), + MarkerKind::ToolCallOpen => self.tool_call_open.as_deref(), + MarkerKind::ToolCallClose => self.tool_call_close.as_deref(), + } + } +} + +#[cfg(test)] +mod tests { + use super::StreamingMarkers; + use crate::token::LlamaToken; + + fn token(id: i32) -> LlamaToken { + LlamaToken::new(id) + } + + #[test] + fn streaming_markers_with_no_markers_reports_none() { + let markers = StreamingMarkers::default(); + assert!(!markers.has_any()); + assert_eq!(markers.max_token_len(), 0); + } + + #[test] + fn streaming_markers_max_token_len_takes_longest() { + let markers = StreamingMarkers { + reasoning_open: Some(vec![token(1)]), + reasoning_close: Some(vec![token(2), token(3), token(4)]), + tool_call_open: Some(vec![token(5), token(6)]), + tool_call_close: None, + }; + assert_eq!(markers.max_token_len(), 3); + } +} diff --git a/llama-cpp-log-decoder/Cargo.toml b/llama-cpp-log-decoder/Cargo.toml new file mode 100644 index 00000000..6746b463 --- /dev/null +++ b/llama-cpp-log-decoder/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "llama-cpp-log-decoder" +description = "Decoder for the llama.cpp / ggml log callback stream" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lints.rust] +unsafe_op_in_unsafe_fn = "warn" +unused_qualifications = "warn" + +[lints.clippy] +all = { level = "deny", priority = -1 } +pedantic = { level = "warn", priority = -1 } +nursery = { level = "warn", priority = -1 } +module_name_repetitions = "allow" + +unwrap_used = "deny" +expect_used = "deny" +panic = "deny" +indexing_slicing = "deny" diff --git a/llama-cpp-log-decoder/src/decode_anomaly.rs b/llama-cpp-log-decoder/src/decode_anomaly.rs new file mode 100644 index 00000000..0607306b --- /dev/null +++ b/llama-cpp-log-decoder/src/decode_anomaly.rs @@ -0,0 +1,5 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DecodeAnomaly { + OrphanCont, + StaleBufferAbandoned, +} diff --git a/llama-cpp-log-decoder/src/decode_output.rs b/llama-cpp-log-decoder/src/decode_output.rs new file mode 100644 index 00000000..924dc853 --- /dev/null +++ b/llama-cpp-log-decoder/src/decode_output.rs @@ -0,0 +1,8 @@ +use crate::log_line::LogLine; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum DecodeOutput { + None, + Line(LogLine), + TwoLines { earlier: LogLine, current: LogLine }, +} diff --git a/llama-cpp-log-decoder/src/decode_result.rs b/llama-cpp-log-decoder/src/decode_result.rs new file mode 100644 index 00000000..be530be5 --- /dev/null +++ b/llama-cpp-log-decoder/src/decode_result.rs @@ -0,0 +1,8 @@ +use crate::decode_anomaly::DecodeAnomaly; +use crate::decode_output::DecodeOutput; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct DecodeResult { + pub output: DecodeOutput, + pub anomaly: Option, +} diff --git a/llama-cpp-log-decoder/src/incoming_log_level.rs b/llama-cpp-log-decoder/src/incoming_log_level.rs new file mode 100644 index 00000000..3cce5605 --- /dev/null +++ b/llama-cpp-log-decoder/src/incoming_log_level.rs @@ -0,0 +1,21 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum IncomingLogLevel { + Cont, + Debug, + Error, + Info, + None, + Unknown(u32), + Warn, +} + +#[cfg(test)] +mod tests { + use super::IncomingLogLevel; + + #[test] + fn unknown_variant_equality() { + assert_eq!(IncomingLogLevel::Unknown(42), IncomingLogLevel::Unknown(42)); + assert_ne!(IncomingLogLevel::Unknown(42), IncomingLogLevel::Unknown(43)); + } +} diff --git a/llama-cpp-log-decoder/src/lib.rs b/llama-cpp-log-decoder/src/lib.rs new file mode 100644 index 00000000..b7a96a37 --- /dev/null +++ b/llama-cpp-log-decoder/src/lib.rs @@ -0,0 +1,17 @@ +//! Decoder for the llama.cpp / ggml log callback stream. +//! +//! The C side delivers log lines in fragments: a missing trailing newline +//! signals that more fragments will follow at `GGML_LOG_LEVEL_CONT`. This +//! crate is a pure `&mut self` transducer — feed `(level, text)` pairs, get +//! complete [`LogLine`]s back when the trailing newline arrives. No globals, +//! no atomics, no FFI, no logger. +//! +//! [`LogLine`]: log_line::LogLine + +pub mod decode_anomaly; +pub mod decode_output; +pub mod decode_result; +pub mod incoming_log_level; +pub mod log_decoder; +pub mod log_level; +pub mod log_line; diff --git a/llama-cpp-log-decoder/src/log_decoder.rs b/llama-cpp-log-decoder/src/log_decoder.rs new file mode 100644 index 00000000..221aba42 --- /dev/null +++ b/llama-cpp-log-decoder/src/log_decoder.rs @@ -0,0 +1,327 @@ +use crate::decode_anomaly::DecodeAnomaly; +use crate::decode_output::DecodeOutput; +use crate::decode_result::DecodeResult; +use crate::incoming_log_level::IncomingLogLevel; +use crate::log_level::LogLevel; +use crate::log_line::LogLine; + +pub struct LogDecoder { + buffered: Option<(LogLevel, String)>, + previous_level: LogLevel, +} + +impl LogDecoder { + #[must_use] + pub const fn new() -> Self { + Self { + buffered: None, + previous_level: LogLevel::None, + } + } + + pub fn feed(&mut self, level: IncomingLogLevel, text: &str) -> DecodeResult { + match level { + IncomingLogLevel::Cont => self.feed_cont(text), + IncomingLogLevel::Debug => self.feed_non_cont(LogLevel::Debug, text), + IncomingLogLevel::Error => self.feed_non_cont(LogLevel::Error, text), + IncomingLogLevel::Info => self.feed_non_cont(LogLevel::Info, text), + IncomingLogLevel::None => self.feed_non_cont(LogLevel::None, text), + IncomingLogLevel::Unknown(raw) => self.feed_non_cont(LogLevel::Unknown(raw), text), + IncomingLogLevel::Warn => self.feed_non_cont(LogLevel::Warn, text), + } + } + + fn feed_cont(&mut self, text: &str) -> DecodeResult { + if let Some((level, mut buffer)) = self.buffered.take() { + buffer.push_str(text); + if let Some(without_newline) = buffer.strip_suffix('\n') { + DecodeResult { + output: DecodeOutput::Line(LogLine { + level, + text: without_newline.to_owned(), + }), + anomaly: None, + } + } else { + self.buffered = Some((level, buffer)); + DecodeResult { + output: DecodeOutput::None, + anomaly: None, + } + } + } else { + self.feed_orphan_cont(text) + } + } + + fn feed_orphan_cont(&mut self, text: &str) -> DecodeResult { + let level = self.previous_level; + if let Some(without_newline) = text.strip_suffix('\n') { + DecodeResult { + output: DecodeOutput::Line(LogLine { + level, + text: without_newline.to_owned(), + }), + anomaly: Some(DecodeAnomaly::OrphanCont), + } + } else { + self.buffered = Some((level, text.to_owned())); + DecodeResult { + output: DecodeOutput::None, + anomaly: Some(DecodeAnomaly::OrphanCont), + } + } + } + + fn feed_non_cont(&mut self, level: LogLevel, text: &str) -> DecodeResult { + self.previous_level = level; + let stale = self.buffered.take(); + match (text.strip_suffix('\n'), stale) { + (Some(without_newline), Some((stale_level, stale_text))) => DecodeResult { + output: DecodeOutput::TwoLines { + earlier: LogLine { + level: stale_level, + text: stale_text, + }, + current: LogLine { + level, + text: without_newline.to_owned(), + }, + }, + anomaly: Some(DecodeAnomaly::StaleBufferAbandoned), + }, + (Some(without_newline), None) => DecodeResult { + output: DecodeOutput::Line(LogLine { + level, + text: without_newline.to_owned(), + }), + anomaly: None, + }, + (None, Some((stale_level, stale_text))) => { + self.buffered = Some((level, text.to_owned())); + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: stale_level, + text: stale_text, + }), + anomaly: Some(DecodeAnomaly::StaleBufferAbandoned), + } + } + (None, None) => { + self.buffered = Some((level, text.to_owned())); + DecodeResult { + output: DecodeOutput::None, + anomaly: None, + } + } + } + } +} + +impl Default for LogDecoder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::LogDecoder; + use crate::decode_anomaly::DecodeAnomaly; + use crate::decode_output::DecodeOutput; + use crate::decode_result::DecodeResult; + use crate::incoming_log_level::IncomingLogLevel; + use crate::log_level::LogLevel; + use crate::log_line::LogLine; + + #[test] + fn feed_complete_info_line() { + let mut decoder = LogDecoder::new(); + let result = decoder.feed(IncomingLogLevel::Info, "hello\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Info, + text: "hello".to_owned(), + }), + anomaly: None, + } + ); + } + + #[test] + fn feed_partial_without_newline() { + let mut decoder = LogDecoder::new(); + let result = decoder.feed(IncomingLogLevel::Info, "hello"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::None, + anomaly: None, + } + ); + } + + #[test] + fn feed_cont_completion() { + let mut decoder = LogDecoder::new(); + decoder.feed(IncomingLogLevel::Info, "hello "); + let result = decoder.feed(IncomingLogLevel::Cont, "world\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Info, + text: "hello world".to_owned(), + }), + anomaly: None, + } + ); + } + + #[test] + fn feed_multi_part_cont() { + let mut decoder = LogDecoder::new(); + decoder.feed(IncomingLogLevel::Info, "part1 "); + decoder.feed(IncomingLogLevel::Cont, "part2 "); + let result = decoder.feed(IncomingLogLevel::Cont, "part3\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Info, + text: "part1 part2 part3".to_owned(), + }), + anomaly: None, + } + ); + } + + #[test] + fn feed_non_cont_while_buffering() { + let mut decoder = LogDecoder::new(); + decoder.feed(IncomingLogLevel::Info, "stale"); + let result = decoder.feed(IncomingLogLevel::Warn, "fresh\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::TwoLines { + earlier: LogLine { + level: LogLevel::Info, + text: "stale".to_owned(), + }, + current: LogLine { + level: LogLevel::Warn, + text: "fresh".to_owned(), + }, + }, + anomaly: Some(DecodeAnomaly::StaleBufferAbandoned), + } + ); + } + + #[test] + fn feed_buffer_replacement() { + let mut decoder = LogDecoder::new(); + decoder.feed(IncomingLogLevel::Info, "first"); + let result = decoder.feed(IncomingLogLevel::Warn, "second"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Info, + text: "first".to_owned(), + }), + anomaly: Some(DecodeAnomaly::StaleBufferAbandoned), + } + ); + + let follow_up = decoder.feed(IncomingLogLevel::Cont, "more\n"); + assert_eq!( + follow_up, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Warn, + text: "secondmore".to_owned(), + }), + anomaly: None, + } + ); + } + + #[test] + fn feed_orphan_cont() { + let mut decoder = LogDecoder::new(); + let result = decoder.feed(IncomingLogLevel::Cont, "ghost\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::None, + text: "ghost".to_owned(), + }), + anomaly: Some(DecodeAnomaly::OrphanCont), + } + ); + } + + #[test] + fn feed_orphan_cont_previous_level() { + let mut decoder = LogDecoder::new(); + decoder.feed(IncomingLogLevel::Warn, "complete\n"); + let result = decoder.feed(IncomingLogLevel::Cont, "ghost\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Warn, + text: "ghost".to_owned(), + }), + anomaly: Some(DecodeAnomaly::OrphanCont), + } + ); + } + + #[test] + fn feed_none_level() { + let mut decoder = LogDecoder::new(); + let result = decoder.feed(IncomingLogLevel::None, "no-level\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::None, + text: "no-level".to_owned(), + }), + anomaly: None, + } + ); + } + + #[test] + fn feed_unknown_level() { + let mut decoder = LogDecoder::new(); + let result = decoder.feed(IncomingLogLevel::Unknown(9999), "weird\n"); + + assert_eq!( + result, + DecodeResult { + output: DecodeOutput::Line(LogLine { + level: LogLevel::Unknown(9999), + text: "weird".to_owned(), + }), + anomaly: None, + } + ); + } +} diff --git a/llama-cpp-log-decoder/src/log_level.rs b/llama-cpp-log-decoder/src/log_level.rs new file mode 100644 index 00000000..0fa157be --- /dev/null +++ b/llama-cpp-log-decoder/src/log_level.rs @@ -0,0 +1,20 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum LogLevel { + Debug, + Error, + Info, + None, + Unknown(u32), + Warn, +} + +#[cfg(test)] +mod tests { + use super::LogLevel; + + #[test] + fn unknown_variant_equality() { + assert_eq!(LogLevel::Unknown(42), LogLevel::Unknown(42)); + assert_ne!(LogLevel::Unknown(42), LogLevel::Unknown(43)); + } +} diff --git a/llama-cpp-log-decoder/src/log_line.rs b/llama-cpp-log-decoder/src/log_line.rs new file mode 100644 index 00000000..71376ccd --- /dev/null +++ b/llama-cpp-log-decoder/src/log_line.rs @@ -0,0 +1,7 @@ +use crate::log_level::LogLevel; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct LogLine { + pub level: LogLevel, + pub text: String, +}