diff --git a/ostool/src/board/serial_stream.rs b/ostool/src/board/serial_stream.rs index 624cd15..470b2df 100644 --- a/ostool/src/board/serial_stream.rs +++ b/ostool/src/board/serial_stream.rs @@ -1,3 +1,8 @@ +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; + use anyhow::Context as _; use futures::{SinkExt, StreamExt}; use tokio::{ @@ -8,7 +13,9 @@ use tokio::{ use tokio_tungstenite::tungstenite::Message; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::board::terminal::ServerControlMessage; +use crate::board::terminal::{ + ServerControlAction, ServerControlMessage, classify_server_control_message, +}; pub type BoxedAsyncRead = Box; pub type BoxedAsyncWrite = Box; @@ -25,73 +32,87 @@ pub async fn connect_serial_stream( .await .with_context(|| format!("failed to connect serial websocket {}", ws_url))?; let (mut ws_sink, mut ws_stream) = stream.split(); + let locally_closed = Arc::new(AtomicBool::new(false)); let (runner_stream, bridge_stream) = tokio::io::duplex(64 * 1024); let (runner_rx, runner_tx) = split(runner_stream); let (mut bridge_rx, mut bridge_tx) = split(bridge_stream); - let read_task = tokio::spawn(async move { - while let Some(message) = ws_stream.next().await { - match message.context("serial websocket read failed")? { - Message::Binary(bytes) => { - tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, &bytes) - .await - .context("failed to write serial websocket bytes")?; - tokio::io::AsyncWriteExt::flush(&mut bridge_tx) - .await - .context("failed to flush serial websocket bytes")?; - } - Message::Text(text) => { - if let Ok(control) = serde_json::from_str::(&text) { - match control.kind.as_str() { - "opened" | "closed" => continue, - "error" => { - let message = control - .message - .unwrap_or_else(|| "serial websocket error".to_string()); - anyhow::bail!("ostool-server serial websocket error: {message}"); + let read_task = tokio::spawn({ + let locally_closed = locally_closed.clone(); + async move { + while let Some(message) = ws_stream.next().await { + match message.context("serial websocket read failed")? { + Message::Binary(bytes) => { + tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, &bytes) + .await + .context("failed to write serial websocket bytes")?; + tokio::io::AsyncWriteExt::flush(&mut bridge_tx) + .await + .context("failed to flush serial websocket bytes")?; + } + Message::Text(text) => { + if let Ok(control) = serde_json::from_str::(&text) { + match classify_server_control_message( + &control, + locally_closed.load(Ordering::SeqCst), + ) { + ServerControlAction::Ignore => continue, + ServerControlAction::Close => break, + ServerControlAction::Error(err) => return Err(err), + ServerControlAction::Forward => {} } - _ => {} } - } - tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, text.as_bytes()) - .await - .context("failed to write text serial websocket payload")?; - tokio::io::AsyncWriteExt::flush(&mut bridge_tx) - .await - .context("failed to flush text serial websocket payload")?; + tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, text.as_bytes()) + .await + .context("failed to write text serial websocket payload")?; + tokio::io::AsyncWriteExt::flush(&mut bridge_tx) + .await + .context("failed to flush text serial websocket payload")?; + } + Message::Close(_) => { + if locally_closed.load(Ordering::SeqCst) { + break; + } + anyhow::bail!( + "ostool-server closed the serial websocket; the board session may have been released" + ); + } + Message::Ping(_) => {} + Message::Pong(_) | Message::Frame(_) => {} } - Message::Close(_) => break, - Message::Ping(_) => {} - Message::Pong(_) | Message::Frame(_) => {} } - } - Ok(()) + Ok(()) + } }); - let write_task = tokio::spawn(async move { - let mut buffer = [0u8; 4096]; - loop { - let read = bridge_rx - .read(&mut buffer) - .await - .context("failed to read runner serial bytes")?; - if read == 0 { - break; + let write_task = tokio::spawn({ + let locally_closed = locally_closed.clone(); + async move { + let mut buffer = [0u8; 4096]; + loop { + let read = bridge_rx + .read(&mut buffer) + .await + .context("failed to read runner serial bytes")?; + if read == 0 { + break; + } + ws_sink + .send(Message::Binary(buffer[..read].to_vec().into())) + .await + .context("serial websocket write failed")?; } - ws_sink - .send(Message::Binary(buffer[..read].to_vec().into())) - .await - .context("serial websocket write failed")?; - } - let _ = ws_sink - .send(Message::Text(r#"{"type":"close"}"#.to_string().into())) - .await; - let _ = ws_sink.send(Message::Close(None)).await; - Ok(()) + locally_closed.store(true, Ordering::SeqCst); + let _ = ws_sink + .send(Message::Text(r#"{"type":"close"}"#.to_string().into())) + .await; + let _ = ws_sink.send(Message::Close(None)).await; + Ok(()) + } }); Ok(( diff --git a/ostool/src/board/terminal.rs b/ostool/src/board/terminal.rs index c45c9f4..19fdf61 100644 --- a/ostool/src/board/terminal.rs +++ b/ostool/src/board/terminal.rs @@ -1,4 +1,9 @@ -use anyhow::Context as _; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; + +use anyhow::{Context as _, anyhow}; use futures::{SinkExt, StreamExt}; use serde::Deserialize; use tokio::sync::mpsc; @@ -13,64 +18,111 @@ pub(crate) struct ServerControlMessage { pub(crate) message: Option, } +pub(crate) enum ServerControlAction { + Ignore, + Close, + Error(anyhow::Error), + Forward, +} + +pub(crate) fn classify_server_control_message( + control: &ServerControlMessage, + locally_closed: bool, +) -> ServerControlAction { + match control.kind.as_str() { + "opened" => ServerControlAction::Ignore, + "closed" => { + if locally_closed { + ServerControlAction::Close + } else { + ServerControlAction::Error(anyhow!( + "ostool-server closed the serial websocket; the board session may have been released" + )) + } + } + "error" => { + let message = control + .message + .clone() + .unwrap_or_else(|| "serial websocket error".to_string()); + ServerControlAction::Error(anyhow!("ostool-server serial websocket error: {message}")) + } + _ => ServerControlAction::Forward, + } +} + pub async fn run_serial_terminal(ws_url: reqwest::Url) -> anyhow::Result<()> { let (stream, _) = tokio_tungstenite::connect_async(ws_url.as_str()) .await .with_context(|| format!("failed to connect serial websocket {}", ws_url))?; let (mut sink, mut stream) = stream.split(); + let locally_closed = Arc::new(AtomicBool::new(false)); let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::>(); let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::>(); - let read_task = tokio::spawn(async move { - while let Some(message) = stream.next().await { - match message.context("serial websocket read failed")? { - Message::Binary(bytes) => { - if inbound_tx.send(bytes.to_vec()).is_err() { - break; + let read_task = tokio::spawn({ + let locally_closed = locally_closed.clone(); + async move { + while let Some(message) = stream.next().await { + match message.context("serial websocket read failed")? { + Message::Binary(bytes) => { + if inbound_tx.send(bytes.to_vec()).is_err() { + break; + } } - } - Message::Text(text) => { - if let Ok(control) = serde_json::from_str::(&text) { - match control.kind.as_str() { - "opened" | "closed" => continue, - "error" => { - let message = control - .message - .unwrap_or_else(|| "serial websocket error".to_string()); - let formatted = format!("\n[ostool-server] {message}\n"); - if inbound_tx.send(formatted.into_bytes()).is_err() { - break; + Message::Text(text) => { + if let Ok(control) = serde_json::from_str::(&text) { + match classify_server_control_message( + &control, + locally_closed.load(Ordering::SeqCst), + ) { + ServerControlAction::Ignore => continue, + ServerControlAction::Close => break, + ServerControlAction::Error(err) => { + let _ = inbound_tx + .send(format!("\n[ostool-server] {err}\n").into_bytes()); + return Err(err); } - break; + ServerControlAction::Forward => {} } - _ => {} + } + if inbound_tx.send(text.bytes().collect()).is_err() { + break; } } - if inbound_tx.send(text.bytes().collect()).is_err() { - break; + Message::Close(_) => { + if locally_closed.load(Ordering::SeqCst) { + break; + } + return Err(anyhow!( + "ostool-server closed the serial websocket; the board session may have been released" + )); } + Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {} } - Message::Close(_) => break, - Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {} } - } - Ok::<(), anyhow::Error>(()) + Ok::<(), anyhow::Error>(()) + } }); - let write_task = tokio::spawn(async move { - while let Some(bytes) = outbound_rx.recv().await { - sink.send(Message::Binary(bytes.into())) - .await - .context("serial websocket write failed")?; - } + let write_task = tokio::spawn({ + let locally_closed = locally_closed.clone(); + async move { + while let Some(bytes) = outbound_rx.recv().await { + sink.send(Message::Binary(bytes.into())) + .await + .context("serial websocket write failed")?; + } - let _ = sink - .send(Message::Text(r#"{"type":"close"}"#.to_string().into())) - .await; - let _ = sink.send(Message::Close(None)).await; - Ok::<(), anyhow::Error>(()) + locally_closed.store(true, Ordering::SeqCst); + let _ = sink + .send(Message::Text(r#"{"type":"close"}"#.to_string().into())) + .await; + let _ = sink.send(Message::Close(None)).await; + Ok::<(), anyhow::Error>(()) + } }); let terminal = AsyncTerminal::new(TerminalConfig { @@ -82,16 +134,45 @@ pub async fn run_serial_terminal(ws_url: reqwest::Url) -> anyhow::Result<()> { .run(inbound_rx, outbound_tx, |_handle, _byte| {}) .await; - read_task.abort(); - if let Err(err) = write_task.await - && !err.is_cancelled() - { - log::debug!("serial websocket writer join error: {err}"); - } - if let Err(err) = read_task.await - && !err.is_cancelled() - { - log::debug!("serial websocket reader join error: {err}"); + let mut write_task = write_task; + let write_result = + tokio::time::timeout(std::time::Duration::from_secs(1), &mut write_task).await; + let mut read_task = read_task; + let read_result = + tokio::time::timeout(std::time::Duration::from_millis(300), &mut read_task).await; + + let write_error = match write_result { + Ok(Ok(Ok(()))) => None, + Ok(Ok(Err(err))) => Some(err), + Ok(Err(err)) if !err.is_cancelled() => { + Some(anyhow!("serial websocket writer join error: {err}")) + } + Ok(Err(_)) => None, + Err(_) => { + write_task.abort(); + let _ = write_task.await; + Some(anyhow!("serial websocket writer shutdown timed out")) + } + }; + let read_error = match read_result { + Ok(Ok(Ok(()))) => None, + Ok(Ok(Err(err))) => Some(err), + Ok(Err(err)) if !err.is_cancelled() => { + Some(anyhow!("serial websocket reader join error: {err}")) + } + Ok(Err(_)) => None, + Err(_) => { + read_task.abort(); + let _ = read_task.await; + Some(anyhow!("serial websocket reader shutdown timed out")) + } + }; + + if let Some(err) = write_error.or(read_error) { + if run_result.is_ok() { + return Err(err); + } + log::warn!("remote serial terminal shutdown failed: {err:#}"); } run_result @@ -99,7 +180,7 @@ pub async fn run_serial_terminal(ws_url: reqwest::Url) -> anyhow::Result<()> { #[cfg(test)] mod tests { - use super::ServerControlMessage; + use super::{ServerControlAction, ServerControlMessage, classify_server_control_message}; #[test] fn parse_server_control_message() { @@ -114,4 +195,24 @@ mod tests { assert_eq!(error.kind, "error"); assert_eq!(error.message.as_deref(), Some("power failed")); } + + #[test] + fn closed_control_message_becomes_error_when_not_locally_closed() { + let control: ServerControlMessage = serde_json::from_str(r#"{"type":"closed"}"#).unwrap(); + match classify_server_control_message(&control, false) { + ServerControlAction::Error(err) => { + assert!(err.to_string().contains("may have been released")); + } + _ => panic!("expected error action"), + } + } + + #[test] + fn closed_control_message_is_normal_when_locally_closed() { + let control: ServerControlMessage = serde_json::from_str(r#"{"type":"closed"}"#).unwrap(); + assert!(matches!( + classify_server_control_message(&control, true), + ServerControlAction::Close + )); + } } diff --git a/ostool/src/run/uboot.rs b/ostool/src/run/uboot.rs index 1ce37c3..6c9c678 100644 --- a/ostool/src/run/uboot.rs +++ b/ostool/src/run/uboot.rs @@ -937,10 +937,8 @@ impl RunnerBackend for RemoteBackend { } async fn finish_console(&mut self) -> anyhow::Result<()> { - if let Some(tasks) = self.console_tasks.take() - && let Err(err) = tasks.shutdown_with_timeout(Duration::from_secs(2)).await - { - log::warn!("remote serial console shutdown did not complete cleanly: {err:#}"); + if let Some(tasks) = self.console_tasks.take() { + tasks.shutdown_with_timeout(Duration::from_secs(2)).await?; } Ok(()) }