diff --git a/crates/code_assistant/src/agent/runner.rs b/crates/code_assistant/src/agent/runner.rs index 6f905bc7..5c4372d4 100644 --- a/crates/code_assistant/src/agent/runner.rs +++ b/crates/code_assistant/src/agent/runner.rs @@ -1049,6 +1049,8 @@ impl Agent { plan: None, // spawn_agent doesn't use plan ui: Some(ui), tool_id: Some(tool_id.to_string()), + session_id: None, + model_name: None, permission_handler: None, // Will be handled by sub-agent runner sub_agent_runner, }; @@ -2238,6 +2240,11 @@ impl Agent { plan: Some(&mut self.plan), ui: Some(self.ui.as_ref()), tool_id: Some(tool_request.id.clone()), + session_id: self.session_id.clone(), + model_name: self + .session_model_config + .as_ref() + .map(|config| config.model_name.clone()), permission_handler: self.permission_handler.as_deref(), sub_agent_runner: self.sub_agent_runner.as_deref(), diff --git a/crates/code_assistant/src/mcp/handler.rs b/crates/code_assistant/src/mcp/handler.rs index 78cab502..d2146d0c 100644 --- a/crates/code_assistant/src/mcp/handler.rs +++ b/crates/code_assistant/src/mcp/handler.rs @@ -268,6 +268,8 @@ impl MessageHandler { plan: None, ui: None, tool_id: None, + session_id: None, + model_name: None, permission_handler: None, sub_agent_runner: None, }; diff --git a/crates/code_assistant/src/tests/mocks.rs b/crates/code_assistant/src/tests/mocks.rs index 1375f396..7d409837 100644 --- a/crates/code_assistant/src/tests/mocks.rs +++ b/crates/code_assistant/src/tests/mocks.rs @@ -230,6 +230,8 @@ pub fn create_test_tool_context<'a>( plan, ui, tool_id, + session_id: None, + model_name: None, permission_handler: None, sub_agent_runner: None, } @@ -1135,6 +1137,8 @@ impl ToolTestFixture { plan: self.plan.as_mut(), ui: self.ui.as_ref().map(|ui| ui as &dyn UserInterface), tool_id: self.tool_id.clone(), + session_id: None, + model_name: None, permission_handler: self.permission_handler.as_deref(), sub_agent_runner: None, } diff --git a/crates/code_assistant/src/tools/core/tool.rs b/crates/code_assistant/src/tools/core/tool.rs index 78f9f409..9f648ff9 100644 --- a/crates/code_assistant/src/tools/core/tool.rs +++ b/crates/code_assistant/src/tools/core/tool.rs @@ -20,6 +20,14 @@ pub struct ToolContext<'a> { pub ui: Option<&'a dyn crate::ui::UserInterface>, /// Optional current tool ID for streaming output pub tool_id: Option, + /// Optional session ID used only for diagnostic logging — lets tools + /// correlate log lines with the session persistence file. Never affects + /// tool behavior; leave `None` when the diag log is not relevant + /// (MCP, tests). + pub session_id: Option, + /// Optional active model display name (from models.json) for model-specific + /// tool behavior flags. + pub model_name: Option, /// Optional permission handler for potentially sensitive operations pub permission_handler: Option<&'a dyn PermissionMediator>, @@ -39,6 +47,8 @@ impl<'a> ToolContext<'a> { plan: None, ui: None, tool_id: None, + session_id: None, + model_name: None, permission_handler: None, sub_agent_runner: None, } diff --git a/crates/code_assistant/src/tools/impls/execute_command.rs b/crates/code_assistant/src/tools/impls/execute_command.rs index 62429b81..465aa6d9 100644 --- a/crates/code_assistant/src/tools/impls/execute_command.rs +++ b/crates/code_assistant/src/tools/impls/execute_command.rs @@ -6,9 +6,12 @@ use crate::ui::streaming::DisplayFragment; use crate::ui::UserInterface; use anyhow::{anyhow, Result}; use command_executor::{SandboxCommandRequest, StreamingCallback}; +use llm::provider_config::ConfigurationSystem; use serde::{Deserialize, Serialize}; use serde_json::json; use std::path::PathBuf; +use std::process::Stdio; +use tracing::{debug, warn}; // Input type for the execute_command tool #[derive(Deserialize, Serialize)] @@ -33,6 +36,40 @@ pub struct ExecuteCommandOutput { pub success: bool, } +fn parse_rtk_response_optimized_command(stdout: &[u8]) -> Option { + let response_optimized_command = String::from_utf8_lossy(stdout).trim().to_string(); + if response_optimized_command.is_empty() { + None + } else { + Some(response_optimized_command) + } +} + +fn should_use_rtk_for_model(model_name: Option<&str>) -> bool { + let Some(model_name) = model_name else { + debug!("No active model name in tool context; use_rtk defaults to false"); + return false; + }; + + let config_system = match ConfigurationSystem::load() { + Ok(config) => config, + Err(err) => { + warn!("Failed to load configuration system for use_rtk lookup: {err}"); + return false; + } + }; + + let Some(model_config) = config_system.get_model(model_name) else { + warn!( + "Model '{}' not found in models.json; use_rtk defaults to false", + model_name + ); + return false; + }; + + model_config.use_rtk +} + // Render implementation for output formatting impl Render for ExecuteCommandOutput { fn status(&self) -> String { @@ -174,6 +211,63 @@ impl Tool for ExecuteCommandTool { context: &mut ToolContext<'a>, input: &mut Self::Input, ) -> Result { + + let use_rtk = should_use_rtk_for_model(context.model_name.as_deref()); + let original_command_line = input.command_line.clone(); + let effective_command_line = if !use_rtk { + original_command_line.clone() + } else { + match tokio::process::Command::new("rtk") + .arg("rewrite") + .arg(&original_command_line) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .await + { + Ok(rtk_process_output) => { + if let Some(output) = + parse_rtk_response_optimized_command(&rtk_process_output.stdout) + { + if !rtk_process_output.status.success() { + let stderr = String::from_utf8_lossy(&rtk_process_output.stderr); + debug!( + "RTK returned non-zero status ({:?}) but produced a rewritten command line; using it. stderr: {}", + rtk_process_output.status.code(), + stderr.trim() + ); + } + + if output != original_command_line { + debug!( + "RTK command-line rewrite details: original={:?} rewritten={:?}", + original_command_line, output + ); + } + output + } else { + if rtk_process_output.status.success() { + warn!( + "RTK command-line rewrite returned empty output with success status, using original command" + ); + } else { + let stderr = String::from_utf8_lossy(&rtk_process_output.stderr); + debug!( + "RTK command-line rewrite returned no output (status: {:?}); using original command. stderr: {}", + rtk_process_output.status.code(), + stderr.trim() + ); + } + original_command_line.clone() + } + } + Err(err) => { + warn!("RTK command-line rewrite is enabled but failed to execute `rtk`: {err}"); + original_command_line.clone() + } + } + }; + // Get explorer for the specified project let explorer = context .project_manager @@ -215,14 +309,14 @@ impl Tool for ExecuteCommandTool { })?; let decision = handler - .request_permission(PermissionRequest { - tool_id: context.tool_id.as_deref(), - tool_name: "execute_command", - reason: PermissionRequestReason::ExecuteCommand { - command_line: &input.command_line, - working_dir: Some(effective_working_dir.as_path()), - }, - }) + .request_permission(PermissionRequest { + tool_id: context.tool_id.as_deref(), + tool_name: "execute_command", + reason: PermissionRequestReason::ExecuteCommand { + command_line: &effective_command_line, + working_dir: Some(effective_working_dir.as_path()), + }, + }) .await?; match decision { @@ -253,7 +347,7 @@ impl Tool for ExecuteCommandTool { context .command_executor .execute_streaming( - &input.command_line, + &effective_command_line, Some(&effective_working_dir), Some(&callback), Some(&sandbox_request), @@ -265,7 +359,7 @@ impl Tool for ExecuteCommandTool { context .command_executor .execute_streaming( - &input.command_line, + &effective_command_line, Some(&effective_working_dir), None, Some(&sandbox_request), @@ -276,7 +370,7 @@ impl Tool for ExecuteCommandTool { Ok(ExecuteCommandOutput { project: input.project.clone(), - command_line: input.command_line.clone(), + command_line: effective_command_line, working_dir: working_dir_path, output: result.output, success: result.success, @@ -386,14 +480,18 @@ mod tests { let result = tool.execute(&mut context, &mut input).await?; // Verify result - assert_eq!(result.command_line, "ls -la"); + assert!( + result.command_line == "ls -la" || result.command_line.ends_with("ls -la"), + "command should be original or RTK-rewritten variant, got: {}", + result.command_line + ); assert_eq!(result.output, "Command output"); // Match expected output from mock assert!(result.success); // Verify command was executed with correct parameters let commands = fixture.command_executor().get_captured_commands(); assert_eq!(commands.len(), 1); - assert_eq!(commands[0].command_line, "ls -la"); + assert_eq!(commands[0].command_line, result.command_line); assert_eq!(commands[0].working_dir, Some(PathBuf::from("./root/src"))); Ok(()) @@ -423,14 +521,19 @@ mod tests { let result = tool.execute(&mut context, &mut input).await?; // Verify result shows failure - assert_eq!(result.command_line, "rm -rf /tmp/nonexistent"); + assert!( + result.command_line == "rm -rf /tmp/nonexistent" + || result.command_line.ends_with("rm -rf /tmp/nonexistent"), + "command should be original or RTK-rewritten variant, got: {}", + result.command_line + ); assert_eq!(result.output, "Command failed: permission denied"); assert!(!result.success); // Verify command was executed let commands = fixture.command_executor().get_captured_commands(); assert_eq!(commands.len(), 1); - assert_eq!(commands[0].command_line, "rm -rf /tmp/nonexistent"); + assert_eq!(commands[0].command_line, result.command_line); assert_eq!(commands[0].working_dir, Some(PathBuf::from("./root"))); Ok(()) @@ -569,4 +672,13 @@ mod tests { Ok(()) } + + #[test] + fn test_parse_rtk_response_optimized_command() { + assert_eq!( + parse_rtk_response_optimized_command(b"rg -n \"foo\" src\n"), + Some("rg -n \"foo\" src".to_string()) + ); + assert_eq!(parse_rtk_response_optimized_command(b"\n \t"), None); + } } diff --git a/crates/llm/src/provider_config.rs b/crates/llm/src/provider_config.rs index 2bc96f75..bc5f264b 100644 --- a/crates/llm/src/provider_config.rs +++ b/crates/llm/src/provider_config.rs @@ -25,6 +25,10 @@ pub struct ModelConfig { pub provider: String, /// Model ID within the provider pub id: String, + /// Whether command output should be optimized through RTK for this model. + /// Defaults to false when omitted in models.json. + #[serde(default)] + pub use_rtk: bool, /// Model-specific configuration pub config: serde_json::Value, /// Maximum context window supported by the model (token count)