Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/code_assistant/src/agent/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,8 @@ impl Agent {
plan: None, // spawn_agent doesn't use plan
ui: Some(ui),
tool_id: Some(tool_id.to_string()),
session_id: None,
model_name: None,
permission_handler: None, // Will be handled by sub-agent runner
sub_agent_runner,
};
Expand Down Expand Up @@ -2238,6 +2240,11 @@ impl Agent {
plan: Some(&mut self.plan),
ui: Some(self.ui.as_ref()),
tool_id: Some(tool_request.id.clone()),
session_id: self.session_id.clone(),
model_name: self
.session_model_config
.as_ref()
.map(|config| config.model_name.clone()),

permission_handler: self.permission_handler.as_deref(),
sub_agent_runner: self.sub_agent_runner.as_deref(),
Expand Down
2 changes: 2 additions & 0 deletions crates/code_assistant/src/mcp/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ impl MessageHandler {
plan: None,
ui: None,
tool_id: None,
session_id: None,
model_name: None,
permission_handler: None,
sub_agent_runner: None,
};
Expand Down
4 changes: 4 additions & 0 deletions crates/code_assistant/src/tests/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ pub fn create_test_tool_context<'a>(
plan,
ui,
tool_id,
session_id: None,
model_name: None,
permission_handler: None,
sub_agent_runner: None,
}
Expand Down Expand Up @@ -1135,6 +1137,8 @@ impl ToolTestFixture {
plan: self.plan.as_mut(),
ui: self.ui.as_ref().map(|ui| ui as &dyn UserInterface),
tool_id: self.tool_id.clone(),
session_id: None,
model_name: None,
permission_handler: self.permission_handler.as_deref(),
sub_agent_runner: None,
}
Expand Down
10 changes: 10 additions & 0 deletions crates/code_assistant/src/tools/core/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ pub struct ToolContext<'a> {
pub ui: Option<&'a dyn crate::ui::UserInterface>,
/// Optional current tool ID for streaming output
pub tool_id: Option<String>,
/// Optional session ID used only for diagnostic logging — lets tools
/// correlate log lines with the session persistence file. Never affects
/// tool behavior; leave `None` when diagnostic logging is not relevant
/// (e.g. MCP, tests).
pub session_id: Option<String>,
/// Optional active model display name (from models.json) for model-specific
/// tool behavior flags.
pub model_name: Option<String>,
/// Optional permission handler for potentially sensitive operations
pub permission_handler: Option<&'a dyn PermissionMediator>,

Expand All @@ -39,6 +47,8 @@ impl<'a> ToolContext<'a> {
plan: None,
ui: None,
tool_id: None,
session_id: None,
model_name: None,
permission_handler: None,
sub_agent_runner: None,
}
Expand Down
142 changes: 127 additions & 15 deletions crates/code_assistant/src/tools/impls/execute_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use crate::ui::streaming::DisplayFragment;
use crate::ui::UserInterface;
use anyhow::{anyhow, Result};
use command_executor::{SandboxCommandRequest, StreamingCallback};
use llm::provider_config::ConfigurationSystem;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::path::PathBuf;
use std::process::Stdio;
use tracing::{debug, warn};

// Input type for the execute_command tool
#[derive(Deserialize, Serialize)]
Expand All @@ -33,6 +36,40 @@ pub struct ExecuteCommandOutput {
pub success: bool,
}

/// Extracts the rewritten command line from the raw stdout bytes of `rtk`.
///
/// The bytes are decoded lossily as UTF-8 (invalid sequences become U+FFFD)
/// and surrounding whitespace is trimmed. Returns `None` when nothing but
/// whitespace remains, otherwise the trimmed text as an owned `String`.
fn parse_rtk_response_optimized_command(stdout: &[u8]) -> Option<String> {
    let decoded = String::from_utf8_lossy(stdout);
    let trimmed = decoded.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}

/// Decides whether command lines should be rewritten through RTK for the
/// given model.
///
/// Returns the `use_rtk` flag of the matching model entry. Falls back to
/// `false` (and logs why) when no model name is supplied, when the
/// configuration system fails to load, or when the model is not found.
fn should_use_rtk_for_model(model_name: Option<&str>) -> bool {
    let name = match model_name {
        Some(name) => name,
        None => {
            debug!("No active model name in tool context; use_rtk defaults to false");
            return false;
        }
    };

    // The configuration is loaded on every call; a load failure is treated as
    // "RTK disabled" rather than as an error for the caller.
    let config_system = match ConfigurationSystem::load() {
        Ok(loaded) => loaded,
        Err(err) => {
            warn!("Failed to load configuration system for use_rtk lookup: {err}");
            return false;
        }
    };

    let lookup = config_system.get_model(name);
    match lookup {
        Some(model_config) => model_config.use_rtk,
        None => {
            warn!(
                "Model '{}' not found in models.json; use_rtk defaults to false",
                name
            );
            false
        }
    }
}

// Render implementation for output formatting
impl Render for ExecuteCommandOutput {
fn status(&self) -> String {
Expand Down Expand Up @@ -174,6 +211,63 @@ impl Tool for ExecuteCommandTool {
context: &mut ToolContext<'a>,
input: &mut Self::Input,
) -> Result<Self::Output> {

let use_rtk = should_use_rtk_for_model(context.model_name.as_deref());
let original_command_line = input.command_line.clone();
let effective_command_line = if !use_rtk {
original_command_line.clone()
} else {
match tokio::process::Command::new("rtk")
.arg("rewrite")
.arg(&original_command_line)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.await
{
Ok(rtk_process_output) => {
if let Some(output) =
parse_rtk_response_optimized_command(&rtk_process_output.stdout)
{
if !rtk_process_output.status.success() {
let stderr = String::from_utf8_lossy(&rtk_process_output.stderr);
debug!(
"RTK returned non-zero status ({:?}) but produced a rewritten command line; using it. stderr: {}",
rtk_process_output.status.code(),
stderr.trim()
);
}

if output != original_command_line {
debug!(
"RTK command-line rewrite details: original={:?} rewritten={:?}",
original_command_line, output
);
}
output
} else {
if rtk_process_output.status.success() {
warn!(
"RTK command-line rewrite returned empty output with success status, using original command"
);
} else {
let stderr = String::from_utf8_lossy(&rtk_process_output.stderr);
debug!(
"RTK command-line rewrite returned no output (status: {:?}); using original command. stderr: {}",
rtk_process_output.status.code(),
stderr.trim()
);
}
original_command_line.clone()
}
}
Err(err) => {
warn!("RTK command-line rewrite is enabled but failed to execute `rtk`: {err}");
original_command_line.clone()
}
}
};

// Get explorer for the specified project
let explorer = context
.project_manager
Expand Down Expand Up @@ -215,14 +309,14 @@ impl Tool for ExecuteCommandTool {
})?;

let decision = handler
.request_permission(PermissionRequest {
tool_id: context.tool_id.as_deref(),
tool_name: "execute_command",
reason: PermissionRequestReason::ExecuteCommand {
command_line: &input.command_line,
working_dir: Some(effective_working_dir.as_path()),
},
})
.request_permission(PermissionRequest {
tool_id: context.tool_id.as_deref(),
tool_name: "execute_command",
reason: PermissionRequestReason::ExecuteCommand {
command_line: &effective_command_line,
working_dir: Some(effective_working_dir.as_path()),
},
})
.await?;

match decision {
Expand Down Expand Up @@ -253,7 +347,7 @@ impl Tool for ExecuteCommandTool {
context
.command_executor
.execute_streaming(
&input.command_line,
&effective_command_line,
Some(&effective_working_dir),
Some(&callback),
Some(&sandbox_request),
Expand All @@ -265,7 +359,7 @@ impl Tool for ExecuteCommandTool {
context
.command_executor
.execute_streaming(
&input.command_line,
&effective_command_line,
Some(&effective_working_dir),
None,
Some(&sandbox_request),
Expand All @@ -276,7 +370,7 @@ impl Tool for ExecuteCommandTool {

Ok(ExecuteCommandOutput {
project: input.project.clone(),
command_line: input.command_line.clone(),
command_line: effective_command_line,
working_dir: working_dir_path,
output: result.output,
success: result.success,
Expand Down Expand Up @@ -386,14 +480,18 @@ mod tests {
let result = tool.execute(&mut context, &mut input).await?;

// Verify result
assert_eq!(result.command_line, "ls -la");
assert!(
result.command_line == "ls -la" || result.command_line.ends_with("ls -la"),
"command should be original or RTK-rewritten variant, got: {}",
result.command_line
);
assert_eq!(result.output, "Command output"); // Match expected output from mock
assert!(result.success);

// Verify command was executed with correct parameters
let commands = fixture.command_executor().get_captured_commands();
assert_eq!(commands.len(), 1);
assert_eq!(commands[0].command_line, "ls -la");
assert_eq!(commands[0].command_line, result.command_line);
assert_eq!(commands[0].working_dir, Some(PathBuf::from("./root/src")));

Ok(())
Expand Down Expand Up @@ -423,14 +521,19 @@ mod tests {
let result = tool.execute(&mut context, &mut input).await?;

// Verify result shows failure
assert_eq!(result.command_line, "rm -rf /tmp/nonexistent");
assert!(
result.command_line == "rm -rf /tmp/nonexistent"
|| result.command_line.ends_with("rm -rf /tmp/nonexistent"),
"command should be original or RTK-rewritten variant, got: {}",
result.command_line
);
assert_eq!(result.output, "Command failed: permission denied");
assert!(!result.success);

// Verify command was executed
let commands = fixture.command_executor().get_captured_commands();
assert_eq!(commands.len(), 1);
assert_eq!(commands[0].command_line, "rm -rf /tmp/nonexistent");
assert_eq!(commands[0].command_line, result.command_line);
assert_eq!(commands[0].working_dir, Some(PathBuf::from("./root")));

Ok(())
Expand Down Expand Up @@ -569,4 +672,13 @@ mod tests {

Ok(())
}

/// Checks that RTK stdout is trimmed into a command line and that
/// whitespace-only output is reported as "no rewrite".
#[test]
fn test_parse_rtk_response_optimized_command() {
    // A rewritten command followed by a newline comes back without the newline.
    let rewritten = parse_rtk_response_optimized_command(b"rg -n \"foo\" src\n");
    let expected = Some("rg -n \"foo\" src".to_string());
    assert_eq!(rewritten, expected);

    // Output consisting solely of whitespace yields no command.
    let blank = parse_rtk_response_optimized_command(b"\n \t");
    assert_eq!(blank, None);
}
}
4 changes: 4 additions & 0 deletions crates/llm/src/provider_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ pub struct ModelConfig {
pub provider: String,
/// Model ID within the provider
pub id: String,
/// Whether command output should be optimized through RTK for this model.
/// Defaults to false when omitted in models.json.
#[serde(default)]
pub use_rtk: bool,
/// Model-specific configuration
pub config: serde_json::Value,
/// Maximum context window supported by the model (token count)
Expand Down