Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions crates/forge_app/src/dto/anthropic/transforms/billing_header.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use forge_domain::Transformer;
use sha2::{Digest, Sha256};

use crate::dto::anthropic::{Content, Message, Request, Role, SystemMessage};

// Mirrors the Claude Code billing metadata transform from
// https://github.com/ex-machina-co/opencode-anthropic-auth/tree/main.

/// Claude Code version reported as `cc_version` in the billing header and
/// mixed into the hashed input of the version suffix.
const CLAUDE_CODE_VERSION: &str = "2.1.87";

/// Salt prepended to the hashed input when computing the version suffix.
const CCH_SALT: &str = "59cf53e54c78";

/// Zero-based character positions sampled from the first user message text for
/// the version suffix. A position past the end of the text contributes `'0'`.
const CCH_POSITIONS: &[usize] = &[4, 7, 20];

/// Entrypoint name reported as `cc_entrypoint` in the billing header.
const ENTRYPOINT: &str = "sdk-cli";

/// Transformer that adds the Anthropic billing metadata block as the first
/// system message of a request.
///
/// The OAuth-backed Claude Code provider uses this metadata shape when sending
/// subscription-authenticated Anthropic requests.
///
/// The block text is derived from the first text block of the first user
/// message. Requests without any messages are returned unchanged.
pub struct BillingHeader;

impl BillingHeader {
    /// Returns the text of the first `Content::Text` block in the earliest
    /// user-role message.
    ///
    /// Yields an empty string when there is no user message, or when that
    /// message carries no text block.
    fn extract_first_user_text(messages: &[Message]) -> String {
        messages
            .iter()
            .find(|message| matches!(message.role, Role::User))
            .and_then(|message| {
                message.content.iter().find_map(|block| {
                    if let Content::Text { text, .. } = block {
                        Some(text.clone())
                    } else {
                        None
                    }
                })
            })
            .unwrap_or_default()
    }

    /// Computes `cch`: the leading 5 hex characters of SHA-256 over `text`.
    fn compute_cch(text: &str) -> String {
        let mut digest_hex = hex::encode(Sha256::digest(text.as_bytes()));
        // The hex digest is ASCII, so truncating by byte length is safe.
        digest_hex.truncate(5);
        digest_hex
    }

    /// Computes the 3-character version suffix.
    ///
    /// The hashed input is `CCH_SALT`, followed by the characters of `text`
    /// found at `CCH_POSITIONS` (with `'0'` standing in for any position past
    /// the end), followed by `CLAUDE_CODE_VERSION`. The suffix is the leading
    /// 3 hex characters of the SHA-256 digest of that input.
    fn compute_version_suffix(text: &str) -> String {
        // NOTE(review): positions count Unicode scalar values (`chars()`);
        // confirm the mirrored implementation does not index UTF-16 code
        // units instead.
        let mut hash_input = String::from(CCH_SALT);
        for &position in CCH_POSITIONS {
            hash_input.push(text.chars().nth(position).unwrap_or('0'));
        }
        hash_input.push_str(CLAUDE_CODE_VERSION);

        let mut digest_hex = hex::encode(Sha256::digest(hash_input.as_bytes()));
        digest_hex.truncate(3);
        digest_hex
    }

    /// Assembles the full billing header line for the given conversation.
    ///
    /// Both the version suffix and `cch` are derived from the same text: the
    /// first text block of the first user message.
    fn build_header_value(messages: &[Message]) -> String {
        let first_user_text = Self::extract_first_user_text(messages);
        let cch = Self::compute_cch(&first_user_text);
        let suffix = Self::compute_version_suffix(&first_user_text);

        format!(
            "x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{suffix}; cc_entrypoint={ENTRYPOINT}; cch={cch};"
        )
    }
}

impl Transformer for BillingHeader {
    type Value = Request;

    /// Inserts the billing header block at index 0 of `request.system`,
    /// creating the system list when it is absent.
    ///
    /// A request with no messages passes through untouched.
    fn transform(&mut self, mut request: Self::Value) -> Self::Value {
        if !request.messages.is_empty() {
            let billing_block = SystemMessage {
                r#type: "text".to_string(),
                text: Self::build_header_value(&request.messages),
                cache_control: None,
            };

            // Index 0 keeps the billing block ahead of any existing system
            // messages.
            request
                .system
                .get_or_insert_with(Default::default)
                .insert(0, billing_block);
        }

        request
    }
}

#[cfg(test)]
mod tests {
    use forge_domain::{Context, ContextMessage, ModelId};

    use super::*;

    /// Converts `context` into an Anthropic request and runs it through
    /// `BillingHeader`.
    fn transform_context(context: Context) -> Request {
        let request = Request::try_from(context).unwrap();
        let mut transformer = BillingHeader;
        transformer.transform(request)
    }

    /// The header carries the version prefix, the entrypoint and a `cch` field.
    #[test]
    fn test_build_header_value_format() {
        let fixture = [Message {
            role: Role::User,
            content: vec![Content::Text {
                text: "Hello world this is a test message for billing".to_string(),
                cache_control: None,
            }],
        }];

        let header = BillingHeader::build_header_value(&fixture);

        assert!(
            header.starts_with("x-anthropic-billing-header: cc_version=2.1.87."),
            "Header should start with correct prefix, got: {header}"
        );
        assert!(
            header.contains("cc_entrypoint=sdk-cli"),
            "Header should contain Claude Code SDK entrypoint, got: {header}"
        );
        assert!(
            header.contains("cch="),
            "Header should contain cch, got: {header}"
        );
    }

    /// With no prior system prompt, the billing header becomes the only
    /// system block.
    #[test]
    fn test_transform_prepends_billing_header() {
        let fixture = Context::default().add_message(ContextMessage::user(
            "test message",
            Some(ModelId::new("claude-3-5-sonnet-20241022")),
        ));

        let actual = transform_context(fixture).system.unwrap();

        assert_eq!(actual.len(), 1);
        let first = &actual[0];
        assert!(
            first.text.starts_with("x-anthropic-billing-header:"),
            "First system block should be billing header, got: {}",
            first.text
        );
    }

    /// Existing system messages are kept and pushed behind the billing header.
    #[test]
    fn test_transform_with_existing_system_messages() {
        let fixture = Context::default()
            .add_message(ContextMessage::system("You are helpful"))
            .add_message(ContextMessage::user(
                "hello",
                Some(ModelId::new("claude-3-5-sonnet-20241022")),
            ));

        let actual = transform_context(fixture).system.unwrap();

        assert_eq!(actual.len(), 2);
        assert!(actual[0].text.starts_with("x-anthropic-billing-header:"));
        assert_eq!(actual[1].text, "You are helpful");
    }

    /// A request without messages must not panic and must not gain a system
    /// block.
    #[test]
    fn test_empty_messages_no_panic() {
        let actual = BillingHeader.transform(Request::default());

        assert!(actual.system.unwrap_or_default().is_empty());
    }
}
2 changes: 2 additions & 0 deletions crates/forge_app/src/dto/anthropic/transforms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod auth_system_message;
mod billing_header;
mod capitalize_tool_names;
mod drop_invalid_toolcalls;
mod enforce_schema;
Expand All @@ -9,6 +10,7 @@ mod sanitize_tool_ids;
mod set_cache;

pub use auth_system_message::AuthSystemMessage;
pub use billing_header::BillingHeader;
pub use capitalize_tool_names::CapitalizeToolNames;
pub use drop_invalid_toolcalls::DropInvalidToolUse;
pub use enforce_schema::EnforceStrictObjectSchema;
Expand Down
127 changes: 108 additions & 19 deletions crates/forge_app/src/dto/anthropic/transforms/set_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use forge_domain::Transformer;

use crate::dto::anthropic::Request;

/// Anthropic rejects requests with more than 4 `cache_control` blocks.
///
/// When more markers are selected than this limit allows, only the last
/// `MAX_CACHE_CONTROL_BLOCKS` selected markers are applied.
const MAX_CACHE_CONTROL_BLOCKS: usize = 4;

/// Transformer that keeps Anthropic prompt-cache markers stable:
/// - Always caches every system message so the static system prefix remains
/// reusable
/// - Prefers newer cache breakpoints because later markers capture more of the
/// prompt prefix than earlier ones
/// - Falls back to caching the first conversation message when there is no
/// system prompt so single-turn requests still establish a reusable prefix
/// - Uses exactly one rolling message-level marker on the newest message
Expand All @@ -14,9 +17,9 @@ impl Transformer for SetCache {
type Value = Request;

/// Applies the default Anthropic cache strategy:
/// 1. Cache every system message when present, otherwise cache the first
/// conversation message.
/// 2. Cache only the last message as the rolling message-level marker.
/// 1. Clear any existing cache markers.
/// 2. Select preferred cache breakpoints.
/// 3. Keep only the newest breakpoints up to Anthropic's 4-block limit.
fn transform(&mut self, mut request: Self::Value) -> Self::Value {
let len = request.get_messages().len();
let sys_len = request.system.as_ref().map_or(0, |msgs| msgs.len());
Expand All @@ -25,36 +28,68 @@ impl Transformer for SetCache {
return request;
}

let has_system_prompt = request
.system
.as_ref()
.is_some_and(|messages| !messages.is_empty());

if let Some(system_messages) = request.system.as_mut() {
for message in system_messages.iter_mut() {
*message = std::mem::take(message).cached(true);
*message = std::mem::take(message).cached(false);
}
}

for message in request.get_messages_mut().iter_mut() {
*message = std::mem::take(message).cached(false);
}

if !has_system_prompt
&& len > 0
&& let Some(first_message) = request.get_messages_mut().first_mut()
{
*first_message = std::mem::take(first_message).cached(true);
let has_system_prompt = request
.system
.as_ref()
.is_some_and(|messages| !messages.is_empty());

let mut desired_markers = Vec::new();

if has_system_prompt {
desired_markers.extend((0..sys_len).map(CacheMarker::System));
} else if len > 0 {
desired_markers.push(CacheMarker::Message(0));
}

if let Some(message) = request.get_messages_mut().last_mut() {
*message = std::mem::take(message).cached(true);
if len > 0 {
let last_message = CacheMarker::Message(len - 1);
if !desired_markers.contains(&last_message) {
desired_markers.push(last_message);
}
}

let keep_from = desired_markers
.len()
.saturating_sub(MAX_CACHE_CONTROL_BLOCKS);
for marker in desired_markers.into_iter().skip(keep_from) {
match marker {
CacheMarker::System(idx) => {
if let Some(message) = request
.system
.as_mut()
.and_then(|messages| messages.get_mut(idx))
{
*message = std::mem::take(message).cached(true);
}
}
CacheMarker::Message(idx) => {
if let Some(message) = request.get_messages_mut().get_mut(idx) {
*message = std::mem::take(message).cached(true);
}
}
}
}

request
}
}

/// A candidate position for a cache marker within a request, identified by
/// index.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CacheMarker {
/// Index into the request's system messages.
System(usize),
/// Index into the request's conversation messages.
Message(usize),
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;
Expand Down Expand Up @@ -225,7 +260,7 @@ mod tests {
}

#[test]
fn test_multiple_system_messages_all_cached() {
fn test_multiple_system_messages_keep_newest_within_limit() {
let fixture = Context {
conversation_id: None,
messages: vec![
Expand Down Expand Up @@ -264,4 +299,58 @@ mod tests {
assert_eq!(actual, expected);
assert!(request.get_messages()[0].is_cached());
}

/// Five system messages plus one user message request six markers; only the
/// newest four may remain cached.
#[test]
fn test_cache_markers_never_exceed_anthropic_limit() {
    let fixture = Context {
        conversation_id: None,
        messages: vec![
            ContextMessage::Text(TextMessage::new(Role::System, "s1")).into(),
            ContextMessage::Text(TextMessage::new(Role::System, "s2")).into(),
            ContextMessage::Text(TextMessage::new(Role::System, "s3")).into(),
            ContextMessage::Text(TextMessage::new(Role::System, "s4")).into(),
            ContextMessage::Text(TextMessage::new(Role::System, "s5")).into(),
            ContextMessage::Text(
                TextMessage::new(Role::User, "user")
                    .model(ModelId::new("claude-3-5-sonnet-20241022")),
            )
            .into(),
        ],
        tools: vec![],
        tool_choice: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        top_k: None,
        reasoning: None,
        stream: None,
        response_format: None,
        initiator: None,
    };

    let request = Request::try_from(fixture).expect("Failed to convert context to request");
    let request = SetCache.transform(request);

    // The two oldest system markers are dropped to respect the limit.
    let system_flags: Vec<bool> = request
        .system
        .as_ref()
        .unwrap()
        .iter()
        .map(|system_message| system_message.is_cached())
        .collect();
    assert_eq!(system_flags, vec![false, false, true, true, true]);
    assert!(request.get_messages()[0].is_cached());

    let cached_system_blocks = system_flags.iter().filter(|&&cached| cached).count();
    let cached_message_blocks = request
        .get_messages()
        .iter()
        .filter(|message| message.is_cached())
        .count();
    assert_eq!(
        cached_system_blocks + cached_message_blocks,
        MAX_CACHE_CONTROL_BLOCKS
    );
}
}
Loading