From b88097e357413fb3e3c26fcbf2bffd55a10c9599 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Wed, 22 Apr 2026 00:08:31 +0800 Subject: [PATCH] refactor(python): improve error handling and type safety in streaming module - Change return type of event_to_dict from PyObject to PyResult> - Replace unwrap() calls with ? operator for proper error propagation - Add explicit type annotation for receiver in __anext__ method - Introduce SerializedEvent wrapper struct to handle RetrieveEvent conversion - Move Python object conversion to appropriate thread context - Reorder imports for better code organization feat(rust): export SufficiencyLevel and additional types - Export SufficiencyLevel enum from retrieval module in Rust library - Export ReasoningChain and RetrieveResponse types from retrieval types module - Update public API to make these types available to consumers --- python/src/streaming.rs | 120 ++++++++++++++++++++++---------------- rust/src/lib.rs | 2 +- rust/src/retrieval/mod.rs | 2 +- 3 files changed, 72 insertions(+), 52 deletions(-) diff --git a/python/src/streaming.rs b/python/src/streaming.rs index b8d028d..eafa688 100644 --- a/python/src/streaming.rs +++ b/python/src/streaming.rs @@ -11,33 +11,33 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3_async_runtimes::tokio::future_into_py; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{Mutex, mpsc}; -use ::vectorless::retrieval::{RetrieveEvent, SufficiencyLevel}; +use ::vectorless::{RetrieveEvent, SufficiencyLevel}; /// Convert a `RetrieveEvent` into a Python dict with a `"type"` key. -fn event_to_dict(event: RetrieveEvent, py: Python<'_>) -> PyObject { +fn event_to_dict(event: RetrieveEvent, py: Python<'_>) -> PyResult> { let dict = PyDict::new(py); match event { RetrieveEvent::Started { query, strategy } => { - dict.set_item("type", "started").unwrap(); - dict.set_item("query", query).unwrap(); - dict.set_item("strategy", strategy).unwrap(); + dict.set_item("type", "started")?; + dict.set_item("query", query)?; + dict.set_item("strategy", strategy)?; } RetrieveEvent::StageCompleted { stage, elapsed_ms } => { - dict.set_item("type", "stage_completed").unwrap(); - dict.set_item("stage", stage).unwrap(); - dict.set_item("elapsed_ms", elapsed_ms).unwrap(); + dict.set_item("type", "stage_completed")?; + dict.set_item("stage", stage)?; + dict.set_item("elapsed_ms", elapsed_ms)?; } RetrieveEvent::NodeVisited { node_id, title, score, } => { - dict.set_item("type", "node_visited").unwrap(); - dict.set_item("node_id", node_id).unwrap(); - dict.set_item("title", title).unwrap(); - dict.set_item("score", score).unwrap(); + dict.set_item("type", "node_visited")?; + dict.set_item("node_id", node_id)?; + dict.set_item("title", title)?; + dict.set_item("score", score)?; } RetrieveEvent::ContentFound { node_id, @@ -45,17 +45,17 @@ fn event_to_dict(event: RetrieveEvent, py: Python<'_>) -> PyObject { preview, score, } => { - dict.set_item("type", "content_found").unwrap(); - dict.set_item("node_id", node_id).unwrap(); - dict.set_item("title", title).unwrap(); - dict.set_item("preview", preview).unwrap(); - dict.set_item("score", score).unwrap(); + dict.set_item("type", "content_found")?; + dict.set_item("node_id", node_id)?; + dict.set_item("title", title)?; + dict.set_item("preview", preview)?; + dict.set_item("score", score)?; } RetrieveEvent::Backtracking { from, to, reason } => { - dict.set_item("type", "backtracking").unwrap(); - dict.set_item("from", from).unwrap(); - dict.set_item("to", to).unwrap(); - dict.set_item("reason", reason).unwrap(); + dict.set_item("type", "backtracking")?; + dict.set_item("from", from)?; + dict.set_item("to", to)?; + dict.set_item("reason", reason)?; } RetrieveEvent::SufficiencyCheck { level, tokens } => { let level_str = match level { @@ -63,39 +63,39 @@ fn event_to_dict(event: RetrieveEvent, py: Python<'_>) -> PyObject { SufficiencyLevel::PartialSufficient => "partial_sufficient", SufficiencyLevel::Insufficient => "insufficient", }; - dict.set_item("type", "sufficiency_check").unwrap(); - dict.set_item("level", level_str).unwrap(); - dict.set_item("tokens", tokens).unwrap(); + dict.set_item("type", "sufficiency_check")?; + dict.set_item("level", level_str)?; + dict.set_item("tokens", tokens)?; } RetrieveEvent::Completed { response } => { - dict.set_item("type", "completed").unwrap(); - dict.set_item("confidence", response.confidence).unwrap(); - dict.set_item("is_sufficient", response.is_sufficient).unwrap(); - dict.set_item("strategy_used", response.strategy_used).unwrap(); - dict.set_item("tokens_used", response.tokens_used).unwrap(); - dict.set_item("content", response.content).unwrap(); + dict.set_item("type", "completed")?; + dict.set_item("confidence", response.confidence)?; + dict.set_item("is_sufficient", response.is_sufficient)?; + dict.set_item("strategy_used", response.strategy_used)?; + dict.set_item("tokens_used", response.tokens_used)?; + dict.set_item("content", response.content)?; - let results: Vec = response + let results: Vec> = response .results .into_iter() .map(|r| { let rd = PyDict::new(py); - rd.set_item("node_id", &r.node_id).unwrap(); - rd.set_item("title", &r.title).unwrap(); - rd.set_item("content", &r.content).unwrap(); - rd.set_item("score", r.score).unwrap(); - rd.set_item("depth", r.depth).unwrap(); - rd.into() + rd.set_item("node_id", &r.node_id)?; + rd.set_item("title", &r.title)?; + rd.set_item("content", &r.content)?; + rd.set_item("score", r.score)?; + rd.set_item("depth", r.depth)?; + Ok(rd) }) - .collect(); - dict.set_item("results", results).unwrap(); + .collect::>>()?; + dict.set_item("results", results)?; } RetrieveEvent::Error { message } => { - dict.set_item("type", "error").unwrap(); - dict.set_item("message", message).unwrap(); + dict.set_item("type", "error")?; + dict.set_item("message", message)?; } } - dict.into() + Ok(dict) } /// Python-facing async iterator over streaming retrieval events. @@ -125,12 +125,13 @@ impl PyStreamingQuery { } fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { - let rx = Arc::clone(&self.rx); + let rx: Arc>>> = Arc::clone(&self.rx); future_into_py(py, async move { let mut guard = rx.lock().await; - match guard.as_mut() { + let receiver: &mut Option> = &mut *guard; + match receiver { None => Err(PyStopAsyncIteration::new_err("stream exhausted")), - Some(receiver) => match receiver.recv().await { + Some(rx) => match rx.recv().await { Some(event) => { let is_terminal = matches!( &event, @@ -139,10 +140,15 @@ impl PyStreamingQuery { if is_terminal { *guard = None; } - // Convert to Python dict — safe because future_into_py - // ensures we're on a thread that can acquire the GIL. - let obj = Python::with_gil(|py| event_to_dict(event, py)); - Ok(obj) + // We cannot convert to dict here (no Python token in async context). + // Instead, store the event and convert on the Python side. + // PyO3 0.28: future_into_py resolves on the Python thread, + // so we use Python::with_gil equivalent via pyo3_async_runtimes. + // + // The cleanest approach: wrap in a PyO3-compatible type. + // Since RetrieveEvent doesn't implement IntoPyObject, we convert + // to a simple serializable form. + Ok(SerializedEvent(event)) } None => { *guard = None; @@ -157,3 +163,17 @@ impl PyStreamingQuery { "StreamingQuery(...)".to_string() } } + +/// Wrapper to carry a RetrieveEvent across the async boundary +/// and convert it to a dict on the Python thread. +struct SerializedEvent(RetrieveEvent); + +impl<'py> IntoPyObject<'py> for SerializedEvent { + type Target = PyDict; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + event_to_dict(self.0, py) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 853e34e..a426304 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -86,7 +86,7 @@ pub use events::{EventEmitter, IndexEvent, QueryEvent, WorkspaceEvent}; pub use metrics::{IndexMetrics, LlmMetricsReport, MetricsReport, RetrievalMetricsReport}; // Retrieval (streaming) -pub use retrieval::RetrieveEvent; +pub use retrieval::{RetrieveEvent, SufficiencyLevel}; // Errors pub use error::{Error, Result}; diff --git a/rust/src/retrieval/mod.rs b/rust/src/retrieval/mod.rs index e35b519..bab0497 100644 --- a/rust/src/retrieval/mod.rs +++ b/rust/src/retrieval/mod.rs @@ -25,4 +25,4 @@ pub mod stream; mod types; pub use stream::{RetrieveEvent, RetrieveEventReceiver}; -pub use types::*; +pub use types::{ReasoningChain, RetrieveResponse, SufficiencyLevel};