diff --git a/src/proxy_server.rs b/src/proxy_server.rs
index 549a4bb..bdd4787 100644
--- a/src/proxy_server.rs
+++ b/src/proxy_server.rs
@@ -3,6 +3,7 @@
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::Duration;
+use bytes::Bytes;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::{TcpListener, TcpStream, UdpSocket};
 use tokio::sync::{mpsc, Mutex};
@@ -965,7 +966,7 @@ struct SocksUdpTarget {
 /// to abort mid-await.
 struct UdpRelaySession {
     sid: String,
-    uplink: mpsc::Sender<Vec<u8>>,
+    uplink: mpsc::Sender<Bytes>,
 }
 
 /// All per-ASSOCIATE UDP relay state behind a single mutex so insertion
@@ -991,7 +992,7 @@ impl UdpRelayState {
         }
     }
 
-    fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Vec<u8>>> {
+    fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Bytes>> {
         self.sessions.get(target).map(|s| s.uplink.clone())
     }
 
@@ -1118,7 +1119,15 @@ async fn handle_socks5_udp_associate(
         client_peer_ip
     );
 
-    let mut buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
+    // Fixed reusable recv buffer. We deliberately don't go the
+    // `BytesMut::split().freeze()` route here even though `tunnel_loop`
+    // does: in TCP the read region IS the payload, but UDP always
+    // slices the SOCKS5 header off, so we'd be copying out anyway —
+    // and a frozen `Bytes` from the recv buf would refcount-pin the
+    // full ~65 KB allocation behind a tiny DNS reply, ballooning
+    // memory under bursts. Right-sized `Bytes::copy_from_slice` on
+    // accepted payloads keeps retention proportional to actual data.
+    let mut recv_buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
     let mut control_buf = [0u8; 1];
     let mut client_addr: Option<SocketAddr> = None;
     let state: Arc<Mutex<UdpRelayState>> = Arc::new(Mutex::new(UdpRelayState::new()));
@@ -1134,7 +1143,7 @@
 
     loop {
        tokio::select! {
-            recv = udp.recv_from(&mut buf) => {
+            recv = udp.recv_from(&mut recv_buf) => {
                 let (n, peer) = match recv {
                     Ok(v) => v,
                     Err(e) => {
@@ -1142,6 +1151,7 @@
                         break;
                     }
                 };
+
                 // Source-IP check: anything not from the SOCKS5 client's
                 // host is dropped silently.
                 if peer.ip() != client_peer_ip {
@@ -1162,9 +1172,10 @@
                 // can race one bad packet to DoS the legitimate client
                 // (whose real datagram, sent from a different ephemeral
                 // port, would then be silently rejected).
-                let Some((target, payload)) = parse_socks5_udp_packet(&buf[..n]) else {
+                let Some((target, payload_off)) = parse_socks5_udp_packet_offsets(&recv_buf[..n]) else {
                     continue;
                 };
+                let payload_slice = &recv_buf[payload_off..n];
 
                 // Issue #213: client-side QUIC block. UDP/443 is
                 // HTTP/3 — drop the datagram silently so the client
@@ -1206,19 +1217,26 @@
                 // the mux. Each datagram costs ~payload * 1.33 in the
                 // batched JSON envelope plus tunnel-node CPU; uncapped,
                 // a runaway client can exhaust Apps Script quota.
-                if payload.len() > MAX_UDP_PAYLOAD_BYTES {
+                if payload_slice.len() > MAX_UDP_PAYLOAD_BYTES {
                     oversized_dropped += 1;
                     if oversized_dropped == 1 || oversized_dropped.is_multiple_of(100) {
                         tracing::debug!(
                             "udp datagram dropped: {} B > {} B (count={})",
-                            payload.len(),
+                            payload_slice.len(),
                             MAX_UDP_PAYLOAD_BYTES,
                             oversized_dropped,
                         );
                     }
                     continue;
                 }
-                let payload = payload.to_vec();
+
+                // Right-sized copy: the queued/in-flight payload owns its
+                // own allocation, so the recv buffer can be reused on the
+                // next iteration without keeping every queued datagram
+                // alive. Sized to the actual payload (≤ MAX_UDP_PAYLOAD_BYTES
+                // = 9 KB after the guard above), not the full ~65 KB recv
+                // buffer.
+                let payload = Bytes::copy_from_slice(payload_slice);
 
                 // Fast path: existing session — push payload onto its
                 // bounded uplink queue, drop on overflow (UDP semantics).
@@ -1292,7 +1310,7 @@
             continue;
         }
 
-        let (uplink_tx, uplink_rx) = mpsc::channel::<Vec<u8>>(UDP_UPLINK_QUEUE);
+        let (uplink_tx, uplink_rx) = mpsc::channel::<Bytes>(UDP_UPLINK_QUEUE);
         let task_mux = mux.clone();
         let task_udp = udp.clone();
         let task_target = target.clone();
@@ -1365,7 +1383,7 @@ async fn udp_session_task(
     sid: String,
     target: SocksUdpTarget,
     client_addr: SocketAddr,
-    mut uplink_rx: mpsc::Receiver<Vec<u8>>,
+    mut uplink_rx: mpsc::Receiver<Bytes>,
 ) {
     let mut backoff = UDP_INITIAL_POLL_DELAY;
     loop {
@@ -1473,7 +1491,20 @@ async fn write_socks5_reply(
     sock.flush().await
 }
 
-fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
+/// Parse the SOCKS5 UDP frame header and return the target plus the byte
+/// offset at which the payload starts. Splitting "structure parsing"
+/// from "give me a payload slice" lets the recv hot path stay on a
+/// fixed reusable `Vec` buffer and only allocate a right-sized
+/// `Bytes::copy_from_slice(&recv_buf[off..n])` for accepted payloads
+/// (after the size guard). DO NOT change this back to a zero-copy
+/// `Bytes::slice` path: that was tried and reverted because slicing
+/// the recv buffer with `bytes` 1.x refcounts the whole ~65 KB
+/// allocation, so a queued tiny DNS reply pinned the full datagram-
+/// sized buffer until it drained — burst retention regressed by
+/// orders of magnitude on UDP-heavy workloads. The thin
+/// `parse_socks5_udp_packet` wrapper below keeps existing `&[u8]`
+/// callers (tests) working.
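+///
+/// Worked example (hypothetical bytes, not captured traffic): an IPv4
+/// DNS query arrives as `[0, 0, 0, 1, 10, 0, 0, 1, 0, 53, ..query..]`:
+/// RSV = 0x0000, FRAG = 0, ATYP = 1, DST.ADDR = 10.0.0.1, DST.PORT = 53
+/// (the RFC 1928 §7 layout). That header is 10 bytes, so this returns
+/// the target plus payload offset 10: `&recv_buf[10..n]` is the DNS message.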
+fn parse_socks5_udp_packet_offsets(buf: &[u8]) -> Option<(SocksUdpTarget, usize)> {
     if buf.len() < 4 || buf[0] != 0 || buf[1] != 0 || buf[2] != 0 {
         return None;
     }
@@ -1528,10 +1559,15 @@
             atyp,
             addr,
         },
-        &buf[pos..],
+        pos,
     ))
 }
 
+fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
+    let (target, off) = parse_socks5_udp_packet_offsets(buf)?;
+    Some((target, &buf[off..]))
+}
+
 fn build_socks5_udp_packet(target: &SocksUdpTarget, payload: &[u8]) -> Vec<u8> {
     let mut out = Vec::with_capacity(4 + target.addr.len() + 2 + payload.len() + 1);
     out.extend_from_slice(&[0, 0, 0, target.atyp]);
diff --git a/src/tunnel_client.rs b/src/tunnel_client.rs
index 98e1572..cb6ce12 100644
--- a/src/tunnel_client.rs
+++ b/src/tunnel_client.rs
@@ -19,6 +19,7 @@
 use std::time::{Duration, Instant};
 
 use base64::engine::general_purpose::STANDARD as B64;
 use base64::Engine;
+use bytes::{Bytes, BytesMut};
 use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tokio::net::TcpStream;
 use tokio::sync::{mpsc, oneshot, Semaphore};
@@ -163,25 +164,26 @@ enum MuxMsg {
     ConnectData {
         host: String,
         port: u16,
-        // Arc so the caller can hand the buffer to the mux AND keep a ref
-        // for the fallback path without an extra 64 KB copy per session.
-        data: Arc<Vec<u8>>,
+        // `Bytes` is internally Arc-backed, so the caller can cheaply
+        // clone() to keep its own reference for the unsupported-fallback
+        // replay path without an extra 64 KB copy per session.
+        data: Bytes,
         reply: BatchedReply,
     },
     Data {
         sid: String,
-        data: Vec<u8>,
+        data: Bytes,
         reply: BatchedReply,
     },
     UdpOpen {
         host: String,
         port: u16,
-        data: Vec<u8>,
+        data: Bytes,
         reply: BatchedReply,
     },
     UdpData {
         sid: String,
-        data: Vec<u8>,
+        data: Bytes,
         reply: BatchedReply,
     },
     Close {
@@ -189,6 +191,25 @@ enum MuxMsg {
     },
 }
 
+/// Raw, not-yet-encoded form of a batch operation. Lives only inside
+/// `mux_loop` and gets converted to `BatchOp` (with base64-encoded `d`)
+/// inside `fire_batch`'s spawned task — keeping the encoding work off
+/// the single mux thread, which previously had to base64 every op
+/// inline before it could move on to the next message.
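+/// For illustration (an assumed wire shape, not a captured envelope):
+/// a `data` op carrying the five raw bytes `hello` leaves the mux as
+/// `PendingOp { op: "data", data: Some(b"hello"), .. }` and is later
+/// serialized by `fire_batch` as roughly
+/// `{"op":"data","sid":"…","d":"aGVsbG8="}`, where `aGVsbG8=` is the
+/// standard base64 of `hello`: 8 chars, matching `encoded_len(5)`.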
+struct PendingOp {
+    op: &'static str,
+    sid: Option<String>,
+    host: Option<String>,
+    port: Option<u16>,
+    /// Raw payload. `None` for empty polls / opless ops; `Some` even
+    /// when empty preserves the connect_data shape (always emits `d`).
+    data: Option<Bytes>,
+    /// True for ops that must serialize `d` even when empty (currently
+    /// only `connect_data`, which uses presence of `d` as the signal
+    /// that the caller is opting into the bundled-first-bytes flow).
+    encode_empty: bool,
+}
+
 pub struct TunnelMux {
     tx: mpsc::Sender<MuxMsg>,
     /// Set to `true` after the first time the tunnel-node rejects
@@ -316,13 +337,13 @@ impl TunnelMux {
         &self,
         host: &str,
         port: u16,
-        data: Vec<u8>,
+        data: impl Into<Bytes>,
     ) -> Result {
         let (reply_tx, reply_rx) = oneshot::channel();
         self.send(MuxMsg::UdpOpen {
             host: host.to_string(),
             port,
-            data,
+            data: data.into(),
             reply: reply_tx,
         })
         .await;
@@ -333,11 +354,15 @@
     }
 
-    pub async fn udp_data(&self, sid: &str, data: Vec<u8>) -> Result {
+    pub async fn udp_data(
+        &self,
+        sid: &str,
+        data: impl Into<Bytes>,
+    ) -> Result {
         let (reply_tx, reply_rx) = oneshot::channel();
         self.send(MuxMsg::UdpData {
             sid: sid.to_string(),
-            data,
+            data: data.into(),
             reply: reply_tx,
         })
         .await;
         match reply_rx.await {
@@ -619,10 +644,8 @@ async fn mux_loop(mut rx: mpsc::Receiver, fronter: Arc, c
         }
 
         // Split: plain connects go parallel, data-bearing ops get batched.
-        let mut data_ops: Vec<BatchOp> = Vec::new();
-        let mut data_replies: Vec<(usize, BatchedReply)> = Vec::new();
+        let mut accum = BatchAccum::new();
         let mut close_sids: Vec<String> = Vec::new();
-        let mut batch_payload_bytes: usize = 0;
 
         for msg in msgs {
             match msg {
@@ -648,68 +671,28 @@ async fn mux_loop(mut rx: mpsc::Receiver, fronter: Arc, c
                     data,
                     reply,
                 } => {
-                    let encoded = Some(B64.encode(data.as_slice()));
-                    let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
-
-                    if !data_ops.is_empty()
-                        && (data_ops.len() >= MAX_BATCH_OPS
-                            || batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
-                    {
-                        fire_batch(
-                            &sems,
-                            &fronter,
-                            std::mem::take(&mut data_ops),
-                            std::mem::take(&mut data_replies),
-                        )
-                        .await;
-                        batch_payload_bytes = 0;
-                    }
-
-                    let idx = data_ops.len();
-                    data_ops.push(BatchOp {
-                        op: "connect_data".into(),
+                    let op_bytes = encoded_len(data.len());
+                    let op = PendingOp {
+                        op: "connect_data",
                         sid: None,
                         host: Some(host),
                         port: Some(port),
-                        d: encoded,
-                    });
-                    data_replies.push((idx, reply));
-                    batch_payload_bytes += op_bytes;
+                        data: Some(data),
+                        encode_empty: true,
+                    };
+                    accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
                 }
                 MuxMsg::Data { sid, data, reply } => {
-                    let encoded = if data.is_empty() {
-                        None
-                    } else {
-                        Some(B64.encode(&data))
-                    };
-                    let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
-
-                    // If adding this op would exceed limits, fire current
-                    // batch first and start a new one.
-                    if !data_ops.is_empty()
-                        && (data_ops.len() >= MAX_BATCH_OPS
-                            || batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
-                    {
-                        fire_batch(
-                            &sems,
-                            &fronter,
-                            std::mem::take(&mut data_ops),
-                            std::mem::take(&mut data_replies),
-                        )
-                        .await;
-                        batch_payload_bytes = 0;
-                    }
-
-                    let idx = data_ops.len();
-                    data_ops.push(BatchOp {
-                        op: "data".into(),
+                    let op_bytes = encoded_len(data.len());
+                    let op = PendingOp {
+                        op: "data",
                         sid: Some(sid),
                         host: None,
                         port: None,
-                        d: encoded,
-                    });
-                    data_replies.push((idx, reply));
-                    batch_payload_bytes += op_bytes;
+                        data: if data.is_empty() { None } else { Some(data) },
+                        encode_empty: false,
+                    };
+                    accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
                 }
                 MuxMsg::UdpOpen {
                     host,
@@ -717,70 +700,28 @@ async fn mux_loop(mut rx: mpsc::Receiver, fronter: Arc, c
                     data,
                     reply,
                 } => {
-                    let encoded = if data.is_empty() {
-                        None
-                    } else {
-                        Some(B64.encode(&data))
-                    };
-                    let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
-
-                    if !data_ops.is_empty()
-                        && (data_ops.len() >= MAX_BATCH_OPS
-                            || batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
-                    {
-                        fire_batch(
-                            &sems,
-                            &fronter,
-                            std::mem::take(&mut data_ops),
-                            std::mem::take(&mut data_replies),
-                        )
-                        .await;
-                        batch_payload_bytes = 0;
-                    }
-
-                    let idx = data_ops.len();
-                    data_ops.push(BatchOp {
-                        op: "udp_open".into(),
+                    let op_bytes = encoded_len(data.len());
+                    let op = PendingOp {
+                        op: "udp_open",
                         sid: None,
                         host: Some(host),
                         port: Some(port),
-                        d: encoded,
-                    });
-                    data_replies.push((idx, reply));
-                    batch_payload_bytes += op_bytes;
+                        data: if data.is_empty() { None } else { Some(data) },
+                        encode_empty: false,
+                    };
+                    accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
                 }
                 MuxMsg::UdpData { sid, data, reply } => {
-                    let encoded = if data.is_empty() {
-                        None
-                    } else {
-                        Some(B64.encode(&data))
-                    };
-                    let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
-
-                    if !data_ops.is_empty()
-                        && (data_ops.len() >= MAX_BATCH_OPS
-                            || batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
-                    {
-                        fire_batch(
-                            &sems,
-                            &fronter,
-                            std::mem::take(&mut data_ops),
-                            std::mem::take(&mut data_replies),
-                        )
-                        .await;
-                        batch_payload_bytes = 0;
-                    }
-
-                    let idx = data_ops.len();
-                    data_ops.push(BatchOp {
-                        op: "udp_data".into(),
+                    let op_bytes = encoded_len(data.len());
+                    let op = PendingOp {
+                        op: "udp_data",
                         sid: Some(sid),
                         host: None,
                         port: None,
-                        d: encoded,
-                    });
-                    data_replies.push((idx, reply));
-                    batch_payload_bytes += op_bytes;
+                        data: if data.is_empty() { None } else { Some(data) },
+                        encode_empty: false,
+                    };
+                    accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
                 }
                 MuxMsg::Close { sid } => {
                     close_sids.push(sid);
@@ -788,21 +729,120 @@ async fn mux_loop(mut rx: mpsc::Receiver, fronter: Arc, c
             }
         }
 
+        // `close` ops piggyback on whatever batch we're about to fire — no
+        // reply channel, no payload, just tell tunnel-node to drop the sid.
         for sid in close_sids {
-            data_ops.push(BatchOp {
-                op: "close".into(),
+            accum.pending_ops.push(PendingOp {
+                op: "close",
                 sid: Some(sid),
                 host: None,
                 port: None,
-                d: None,
+                data: None,
+                encode_empty: false,
             });
         }
 
-        if data_ops.is_empty() {
+        if accum.pending_ops.is_empty() {
             continue;
         }
 
-        fire_batch(&sems, &fronter, data_ops, data_replies).await;
+        fire_batch(&sems, &fronter, accum.pending_ops, accum.data_replies).await;
     }
 }
 
+/// Per-iteration accumulator for `mux_loop`. Owns the three fields that
+/// the data-bearing arms used to mutate in lockstep, with a single
+/// `push_or_fire` entry point so the cap-then-push pattern lives in one
+/// place instead of being copy-pasted into every arm.
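+///
+/// Walkthrough (assuming `MAX_BATCH_OPS` is 50, as the threshold-test
+/// comments below do): ops 1–50 accumulate into one batch; op 51 first
+/// fires those 50, then lands at index 0 of a fresh accumulator, which
+/// is the restart that `batch_accum_reindexes_after_flush` pins down.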
+struct BatchAccum {
+    pending_ops: Vec<PendingOp>,
+    data_replies: Vec<(usize, BatchedReply)>,
+    payload_bytes: usize,
+}
+
+impl BatchAccum {
+    fn new() -> Self {
+        Self {
+            pending_ops: Vec::new(),
+            data_replies: Vec::new(),
+            payload_bytes: 0,
+        }
+    }
+
+    /// Append `op` (with its `reply` channel and pre-computed `op_bytes`),
+    /// firing the current accumulator first if `op` would push us past
+    /// `MAX_BATCH_OPS` or `MAX_BATCH_PAYLOAD_BYTES`. After a fire the
+    /// accumulator is fresh for the new op.
+    async fn push_or_fire(
+        &mut self,
+        op: PendingOp,
+        op_bytes: usize,
+        reply: BatchedReply,
+        sems: &Arc<Vec<Arc<Semaphore>>>,
+        fronter: &Arc,
+    ) {
+        if should_fire(self.pending_ops.len(), self.payload_bytes, op_bytes) {
+            fire_batch(
+                sems,
+                fronter,
+                std::mem::take(&mut self.pending_ops),
+                std::mem::take(&mut self.data_replies),
+            )
+            .await;
+            self.payload_bytes = 0;
+        }
+        let idx = self.pending_ops.len();
+        self.pending_ops.push(op);
+        self.data_replies.push((idx, reply));
+        self.payload_bytes += op_bytes;
+    }
+}
+
+/// Threshold predicate for `BatchAccum::push_or_fire`: would adding an
+/// op of `op_bytes` to a batch already holding `pending_len` ops and
+/// `payload_bytes` of base64 cross either the per-batch op cap or
+/// the payload-size cap?
+///
+/// Extracted from the inline `if` so the tunable boundary — including
+/// the "first op never fires" rule (`pending_len == 0`) — has direct
+/// unit-test coverage without spinning up a real `fire_batch`.
+///
+/// `saturating_add` keeps the helper's contract self-contained: a
+/// pathological `op_bytes` near `usize::MAX` clamps to "yes, fire"
+/// rather than wrapping around and silently letting an oversized op
+/// slip past the cap. Today's callers only feed `encoded_len(n)` on
+/// reasonable buffer sizes, but the predicate is the wrong place to
+/// rely on caller bounds.
+fn should_fire(pending_len: usize, payload_bytes: usize, op_bytes: usize) -> bool {
+    pending_len > 0
+        && (pending_len >= MAX_BATCH_OPS
+            || payload_bytes.saturating_add(op_bytes) > MAX_BATCH_PAYLOAD_BYTES)
+}
+
+/// Exact base64-encoded length of `n` raw bytes (standard padding):
+/// `((n + 2) / 3) * 4`. Used by `mux_loop` to enforce
+/// `MAX_BATCH_PAYLOAD_BYTES` without doing the actual encoding inline —
+/// that work now happens in `fire_batch`'s spawned task.
+fn encoded_len(n: usize) -> usize {
+    n.div_ceil(3) * 4
+}
+
+/// Build the wire-shape `BatchOp` from an internal `PendingOp`. Free
+/// function so the encoding contract — non-empty data → encoded,
+/// empty connect_data → `Some("")`, anything else empty → `None` — is
+/// directly testable without spinning up the mux loop.
+fn encode_pending(p: PendingOp) -> BatchOp {
+    let d = match (&p.data, p.encode_empty) {
+        (Some(b), _) if !b.is_empty() => Some(B64.encode(b)),
+        (Some(_), true) => Some(String::new()),
+        _ => None,
+    };
+    BatchOp {
+        op: p.op.into(),
+        sid: p.sid,
+        host: p.host,
+        port: p.port,
+        d,
+    }
+}
@@ -815,7 +855,7 @@
 async fn fire_batch(
     sems: &Arc<Vec<Arc<Semaphore>>>,
     fronter: &Arc,
-    data_ops: Vec<BatchOp>,
+    pending_ops: Vec<PendingOp>,
     data_replies: Vec<(usize, BatchedReply)>,
 ) {
     let script_id = fronter.next_script_id();
@@ -829,7 +869,13 @@
     tokio::spawn(async move {
         let _permit = permit;
         let t0 = std::time::Instant::now();
-        let n_ops = data_ops.len();
+        let n_ops = pending_ops.len();
+
+        // Encode payloads to base64 here, off the single mux thread.
+        // With 50 ops × 64 KB this is up to ~3 MB of work; doing it on
+        // the mux task previously serialized every op behind whichever
+        // batch was currently encoding.
+        let data_ops: Vec<BatchOp> = pending_ops.into_iter().map(encode_pending).collect();
 
         // Bounded-wait: if the batch takes longer than the configured
         // batch timeout (Config::request_timeout_secs), all sessions in
@@ -985,14 +1031,13 @@ pub async fn tunnel_connection(
             mux.record_preread_skip_port(port);
             None
         } else {
-            let mut buf = vec![0u8; 65536];
+            let mut buf = BytesMut::with_capacity(65536);
             let t0 = Instant::now();
-            match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
+            match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read_buf(&mut buf)).await {
                 Ok(Ok(0)) => return Ok(()),
-                Ok(Ok(n)) => {
+                Ok(Ok(_)) => {
                     mux.record_preread_win(port, t0.elapsed());
-                    buf.truncate(n);
-                    Some(Arc::new(buf))
+                    Some(buf.freeze())
                 }
                 Ok(Err(e)) => return Err(e),
                 Err(_) => {
@@ -1008,14 +1053,10 @@
             ConnectDataOutcome::Unsupported => {
                 mux.mark_connect_data_unsupported();
                 let sid = connect_plain(host, port, mux).await?;
-                // Recover the buffered ClientHello from the Arc so the
-                // first tunnel_loop iteration can replay it. The mux task
-                // may still hold the other ref during the unsupported
-                // reply's settle window — fall back to a clone in that
-                // race (rare; the reply path drops its ref before we
-                // reach here in practice).
-                let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
-                (sid, None, Some(bytes))
+                // Replay the buffered ClientHello on the first tunnel_loop
+                // iteration. `Bytes::clone()` is a cheap Arc bump — no
+                // copy of the 64 KB buffer.
+                (sid, None, Some(data))
             }
         },
         None => (connect_plain(host, port, mux).await?, None, None),
@@ -1107,7 +1148,7 @@ async fn connect_plain(host: &str, port: u16, mux: &Arc) -> std::io::
 async fn connect_with_initial_data(
     host: &str,
     port: u16,
-    data: Arc<Vec<u8>>,
+    data: Bytes,
     mux: &Arc<TunnelMux>,
 ) -> std::io::Result<ConnectDataOutcome> {
     let (reply_tx, reply_rx) = oneshot::channel();
@@ -1212,10 +1253,30 @@ async fn tunnel_loop(
     sock: &mut TcpStream,
     sid: &str,
     mux: &Arc<TunnelMux>,
-    mut pending_client_data: Option<Vec<u8>>,
+    mut pending_client_data: Option<Bytes>,
 ) -> std::io::Result<()> {
     let (mut reader, mut writer) = sock.split();
-    let mut buf = vec![0u8; 65536];
+    // `BytesMut` + `read_buf` + a per-read decision between
+    // `split().freeze()` (zero-copy) and `copy_from_slice` + `clear`
+    // (right-sized copy, buffer reused).
+    //
+    // Why the split decision: `bytes` 1.x refcounts the *whole*
+    // backing allocation, so a frozen `Bytes` from a partial read
+    // pins all `READ_CHUNK` bytes until it drops. Under semaphore
+    // saturation or reply timeouts, dozens of small TLS records or
+    // HTTP/2 frames can each retain ~64 KB instead of their actual
+    // payload size — an order-of-magnitude memory regression on
+    // constrained targets (router builds with 64 MB RAM).
+    //
+    // Threshold: at ≥ half-buffer the saved memcpy outweighs the
+    // wasted slack, and these reads are typically streaming bulk
+    // transfer where the `Bytes` flushes through the mux quickly.
+    // Below that, copy out and `clear()` so the same allocation
+    // serves the next read — equivalent memory profile to the old
+    // `vec![0u8; 65536]` + `to_vec()` code on small-read workloads.
+    const READ_CHUNK: usize = 65536;
+    const ZERO_COPY_THRESHOLD: usize = READ_CHUNK / 2;
+    let mut buf = BytesMut::with_capacity(READ_CHUNK);
     let mut consecutive_empty = 0u32;
 
     loop {
@@ -1254,11 +1315,28 @@
             (true, _) => Duration::from_secs(30),
         };
 
-        match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
+        buf.reserve(READ_CHUNK);
+        match tokio::time::timeout(read_timeout, reader.read_buf(&mut buf)).await {
             Ok(Ok(0)) => break,
             Ok(Ok(n)) => {
                 consecutive_empty = 0;
-                Some(buf[..n].to_vec())
+                if n >= ZERO_COPY_THRESHOLD {
+                    // Big read: split off the filled region. The
+                    // frozen `Bytes` is at least half full, so the
+                    // saved 64 KB memcpy outweighs the brief
+                    // retention until the mux drains.
+                    Some(buf.split().freeze())
+                } else {
+                    // Small read: copy out a payload-sized `Bytes`
+                    // and `clear()` so the buffer is reused on the
+                    // next iter (no `reserve` allocation needed
+                    // because the alloc stays uniquely owned).
+                    // Bounds retention to actual data even when
+                    // the mux is backpressured.
+                    let owned = Bytes::copy_from_slice(&buf[..n]);
+                    buf.clear();
+                    Some(owned)
+                }
             }
             Ok(Err(_)) => break,
             Err(_) => None,
@@ -1275,7 +1353,7 @@
             continue;
         }
 
-        let data = client_data.unwrap_or_default();
+        let data = client_data.unwrap_or_else(Bytes::new);
         let was_empty_poll = data.is_empty();
 
         let (reply_tx, reply_rx) = oneshot::channel();
@@ -1664,7 +1742,7 @@ mod tests {
         let mut server_side = accept.await.unwrap();
 
         let (mux, mut rx) = mux_for_test();
-        let pending = Some(b"CLIENTHELLO".to_vec());
+        let pending = Some(Bytes::from_static(b"CLIENTHELLO"));
 
         let loop_handle = tokio::spawn({
             let mux = mux.clone();
@@ -1907,6 +1985,199 @@
             );
         }
 
+    #[test]
+    fn should_fire_first_op_never_fires() {
+        // Empty accumulator: even a single op larger than the payload cap
+        // must not fire — there's nothing to fire yet, and the op gets
+        // added (it will simply be the only op in the next batch).
+        assert!(!should_fire(0, 0, 0));
+        assert!(!should_fire(0, 0, MAX_BATCH_PAYLOAD_BYTES + 1_000_000));
+    }
+
+    #[test]
+    fn should_fire_at_max_ops_threshold() {
+        // 49 already-queued ops + 50th: still fits (boundary is `>=`).
+        assert!(!should_fire(MAX_BATCH_OPS - 1, 0, 100));
+        // 50 already-queued ops + 51st: must fire.
+        assert!(should_fire(MAX_BATCH_OPS, 0, 100));
+        // Well past the cap: must fire.
+        assert!(should_fire(MAX_BATCH_OPS + 5, 0, 100));
+    }
+
+    #[test]
+    fn should_fire_when_payload_would_exceed_cap() {
+        // Exactly at the cap is fine — strict `>`.
+        assert!(!should_fire(
+            10,
+            MAX_BATCH_PAYLOAD_BYTES - 100,
+            100,
+        ));
+        // One byte over: fire.
+        assert!(should_fire(
+            10,
+            MAX_BATCH_PAYLOAD_BYTES - 100,
+            101,
+        ));
+        // Payload already at the cap: even a 1-byte op must fire.
+        assert!(should_fire(
+            10,
+            MAX_BATCH_PAYLOAD_BYTES,
+            1,
+        ));
+    }
+
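+    // Sketch tests for `encoded_len` (assumes standard padded base64,
+    // which is what `B64` is): the pre-encode byte count charged
+    // against MAX_BATCH_PAYLOAD_BYTES must equal what the encoder
+    // actually emits, or the cap accounting drifts.
+    #[test]
+    fn encoded_len_matches_real_base64_output() {
+        // Boundary sizes: 0 → 0, 1..=3 → 4, 4 → 8, plus a 9 KB UDP
+        // payload and a full 64 KB read.
+        for n in [0usize, 1, 2, 3, 4, 5, 6, 9 * 1024, 65536] {
+            assert_eq!(encoded_len(n), B64.encode(vec![0u8; n]).len());
+        }
+    }
+
+    #[test]
+    fn encoded_len_matches_encode_pending_output() {
+        // Same invariant end-to-end: the `d` string `fire_batch`
+        // serializes must be exactly as long as `mux_loop` predicted
+        // when it enforced the payload cap.
+        for n in [1usize, 3, 1024, 9 * 1024] {
+            let op = PendingOp {
+                op: "data",
+                sid: Some("sid-len".into()),
+                host: None,
+                port: None,
+                data: Some(Bytes::from(vec![0xAB; n])),
+                encode_empty: false,
+            };
+            let d = encode_pending(op).d.expect("non-empty data must encode");
+            assert_eq!(d.len(), encoded_len(n));
+        }
+    }
+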
+    /// Reply indices must point at the slot the op occupies *within its
+    /// batch*. Pre-flush ops are 0..N-1 in batch A; post-flush ops
+    /// restart at 0 in batch B. If this regresses, `fire_batch`'s
+    /// `batch_resp.r.get(idx)` lookup hands the wrong response (or
+    /// `None`) to the wrong session — silent data corruption that
+    /// the encode-layer tests can't catch.
+    #[tokio::test]
+    async fn batch_accum_reindexes_after_flush() {
+        // Stand-alone helper that mirrors `push_or_fire`'s push step
+        // without the fire_batch call — lets us simulate a flush with
+        // `mem::take` and assert the post-flush indexing without
+        // mocking the whole tunnel_request stack.
+        fn push_no_fire(
+            accum: &mut BatchAccum,
+            op: PendingOp,
+            op_bytes: usize,
+            reply: BatchedReply,
+        ) {
+            let idx = accum.pending_ops.len();
+            accum.pending_ops.push(op);
+            accum.data_replies.push((idx, reply));
+            accum.payload_bytes += op_bytes;
+        }
+
+        let mk_op = |sid: &str| PendingOp {
+            op: "data",
+            sid: Some(sid.into()),
+            host: None,
+            port: None,
+            data: Some(Bytes::from_static(b"x")),
+            encode_empty: false,
+        };
+        let mk_reply = || oneshot::channel().0;
+
+        let mut accum = BatchAccum::new();
+
+        // Batch A: 3 ops at indices 0, 1, 2.
+        push_no_fire(&mut accum, mk_op("a0"), 4, mk_reply());
+        push_no_fire(&mut accum, mk_op("a1"), 4, mk_reply());
+        push_no_fire(&mut accum, mk_op("a2"), 4, mk_reply());
+        assert_eq!(accum.pending_ops.len(), 3);
+        assert_eq!(
+            accum.data_replies.iter().map(|(i, _)| *i).collect::<Vec<_>>(),
+            vec![0, 1, 2],
+        );
+        assert_eq!(accum.payload_bytes, 12);
+
+        // Simulate the flush: take the queued state and reset the byte
+        // counter (matches what `push_or_fire` does after `fire_batch`).
+        let _flushed_ops = std::mem::take(&mut accum.pending_ops);
+        let _flushed_replies = std::mem::take(&mut accum.data_replies);
+        accum.payload_bytes = 0;
+
+        // Batch B: 2 ops, indices restart at 0.
+        push_no_fire(&mut accum, mk_op("b0"), 4, mk_reply());
+        push_no_fire(&mut accum, mk_op("b1"), 4, mk_reply());
+        assert_eq!(accum.pending_ops.len(), 2);
+        assert_eq!(
+            accum.data_replies.iter().map(|(i, _)| *i).collect::<Vec<_>>(),
+            vec![0, 1],
+            "post-flush indices must restart at 0 — otherwise fire_batch's \
+             batch_resp.r.get(idx) returns None and every session in the \
+             second batch sees a missing-response error"
+        );
+        assert_eq!(accum.payload_bytes, 8);
+    }
+
+    #[test]
+    fn encode_pending_data_op_with_payload_emits_base64() {
+        let op = PendingOp {
+            op: "data",
+            sid: Some("sid-1".into()),
+            host: None,
+            port: None,
+            data: Some(Bytes::from_static(b"hello")),
+            encode_empty: false,
+        };
+        let b = encode_pending(op);
+        assert_eq!(b.op, "data");
+        assert_eq!(b.sid.as_deref(), Some("sid-1"));
+        assert_eq!(b.d.as_deref(), Some(B64.encode(b"hello").as_str()));
+    }
+
+    #[test]
+    fn encode_pending_omits_d_for_empty_polls_and_close() {
+        // Empty-poll Data: mux_loop converts empty Bytes to data: None.
+        let empty_poll = PendingOp {
+            op: "data",
+            sid: Some("sid-2".into()),
+            host: None,
+            port: None,
+            data: None,
+            encode_empty: false,
+        };
+        assert!(encode_pending(empty_poll).d.is_none());
+
+        // UDP poll with no payload: same shape.
+        let udp_poll = PendingOp {
+            op: "udp_data",
+            sid: Some("sid-3".into()),
+            host: None,
+            port: None,
+            data: None,
+            encode_empty: false,
+        };
+        assert!(encode_pending(udp_poll).d.is_none());
+
+        // Close has no data and no reply — `d` must stay omitted.
+        let close = PendingOp {
+            op: "close",
+            sid: Some("sid-4".into()),
+            host: None,
+            port: None,
+            data: None,
+            encode_empty: false,
+        };
+        assert!(encode_pending(close).d.is_none());
+    }
+
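+    // Sketch test (assumes `bytes` 1.x semantics): pins the two
+    // behaviors the `tunnel_loop` read-path comments rely on, so a
+    // dependency bump that changes them fails loudly here instead of
+    // as a quiet memory regression.
+    #[test]
+    fn bytesmut_reuse_semantics_behind_the_read_paths() {
+        // Small-read path: copy out, then clear: the allocation
+        // survives for the next read, retention bounded to the copy.
+        let mut buf = BytesMut::with_capacity(64);
+        buf.extend_from_slice(b"hello");
+        let copied = Bytes::copy_from_slice(&buf[..]);
+        buf.clear();
+        assert_eq!(&copied[..], b"hello");
+        assert!(buf.capacity() >= 64, "clear() must keep the allocation");
+
+        // Big-read path: split().freeze() hands the filled region off
+        // without a copy and leaves the BytesMut empty; the frozen
+        // Bytes keeps the backing allocation alive until it drops.
+        buf.extend_from_slice(b"world");
+        let frozen = buf.split().freeze();
+        assert_eq!(&frozen[..], b"world");
+        assert!(buf.is_empty());
+    }
+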
+    #[test]
+    fn encode_pending_connect_data_emits_empty_string_when_data_is_empty() {
+        // Defensive: ConnectData's wire contract is that `d` is always
+        // present (its presence is the signal that the caller is opting
+        // into the bundled-first-bytes flow). If an empty Bytes ever
+        // reaches the encoder, we must serialize `d: ""`, not omit it.
+        let op = PendingOp {
+            op: "connect_data",
+            sid: None,
+            host: Some("example.com".into()),
+            port: Some(443),
+            data: Some(Bytes::new()),
+            encode_empty: true,
+        };
+        let b = encode_pending(op);
+        assert_eq!(b.op, "connect_data");
+        assert_eq!(b.d.as_deref(), Some(""));
+    }
+
+    #[test]
+    fn encode_pending_connect_data_with_payload_encodes_normally() {
+        let op = PendingOp {
+            op: "connect_data",
+            sid: None,
+            host: Some("example.com".into()),
+            port: Some(443),
+            data: Some(Bytes::from_static(b"\x16\x03\x01")), // ClientHello prefix
+            encode_empty: true,
+        };
+        let b = encode_pending(op);
+        assert_eq!(b.d.as_deref(), Some(B64.encode(b"\x16\x03\x01").as_str()));
+    }
+
     #[test]
     fn preread_counters_track_each_outcome() {
         let (mux, _rx) = mux_for_test();