diff --git a/Cargo.toml b/Cargo.toml index 0b029bca..af286be4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,13 +45,23 @@ serde_json = "1" take_mut = "0.2.2" thiserror = "1" tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] } +# Keep these compatible with rust-toolchain 1.84.1. +indexmap = "=2.13.1" +icu_collections = "<2.2" +icu_locale_core = "<2.2" +icu_normalizer = "<2.2" +icu_normalizer_data = "<2.2" +icu_properties = "<2.2" +icu_properties_data = "<2.2" +icu_provider = "<2.2" + tonic = { version = "0.10", features = ["tls", "gzip"] } [dev-dependencies] clap = "2" env_logger = "0.10" fail = { version = "0.4", features = ["failpoints"] } -proptest = "1" +proptest = "<1.11" proptest-derive = "0.5.1" reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } rstest = "0.18.2" diff --git a/proto-build/Cargo.toml b/proto-build/Cargo.toml index ca3fad6e..dc88e313 100644 --- a/proto-build/Cargo.toml +++ b/proto-build/Cargo.toml @@ -13,5 +13,6 @@ edition = "2021" [dependencies] glob = "0.3" tonic-build = { version = "0.10", features = ["cleanup-markdown"] } -# Keep this compatible with rust-toolchain 1.84.1. +# Keep these compatible with rust-toolchain 1.84.1. +indexmap = "=2.13.1" tempfile = "=3.14.0" diff --git a/src/request/plan.rs b/src/request/plan.rs index 8bd15bb5..4ae009e2 100644 --- a/src/request/plan.rs +++ b/src/request/plan.rs @@ -5,11 +5,12 @@ use std::sync::Arc; use async_recursion::async_recursion; use async_trait::async_trait; -use futures::future::try_join_all; use futures::prelude::*; use log::debug; +use log::error; use log::info; use tokio::sync::Semaphore; +use tokio::task::JoinSet; use tokio::time::sleep; use crate::backoff::Backoff; @@ -92,6 +93,35 @@ pub(crate) fn is_grpc_error(e: &Error) -> bool { matches!(e, Error::GrpcAPI(_) | Error::Grpc(_)) } +async fn collect_join_set_results( + mut join_set: JoinSet<(usize, T)>, + task_count: usize, + handler_name: &str, +) -> Result> +where + T: Send + 'static, +{ + let mut results = (0..task_count).map(|_| None).collect::>(); + while let Some(join_result) = join_set.join_next().await { + match join_result { + Ok((idx, val)) => results[idx] = Some(val), + Err(e) => { + error!( + "{}: failed to join task ({} tasks): {}", + handler_name, task_count, e + ); + join_set.shutdown().await; + return Err(Error::JoinError(e)); + } + } + } + + Ok(results + .into_iter() + .map(|result| result.expect("all spawned tasks should return a result")) + .collect()) +} + pub struct RetryableMultiRegion { pub(super) inner: P, pub pd_client: Arc, @@ -117,23 +147,39 @@ where preserve_region_results: bool, ) -> Result<::Result> { let shards = current_plan.shards(&pd_client).collect::>().await; - debug!("single_plan_handler, shards: {}", shards.len()); - let mut handles = Vec::with_capacity(shards.len()); - for shard in shards { - let (shard, region) = shard?; + let shards_len = shards.len(); + debug!("single_plan_handler, shards: {}", shards_len); + let mut join_set = JoinSet::new(); + for (idx, shard) in shards.into_iter().enumerate() { + let (shard, region) = match shard { + Ok(shard) => shard, + Err(e) => { + join_set.shutdown().await; + return Err(e); + } + }; let clone = current_plan.clone_then_apply_shard(shard); - let handle = tokio::spawn(Self::single_shard_handler( - pd_client.clone(), - clone, - region, - backoff.clone(), - permits.clone(), - preserve_region_results, - )); - handles.push(handle); + let pd_client = pd_client.clone(); + let backoff = backoff.clone(); + let permits = permits.clone(); + join_set.spawn(async move { + ( + idx, + Self::single_shard_handler( + pd_client, + clone, + region, + backoff, + permits, + preserve_region_results, + ) + .await, + ) + }); } - let results = try_join_all(handles).await?; + let results = collect_join_set_results(join_set, shards_len, "single_plan_handler").await?; + if preserve_region_results { Ok(results .into_iter() @@ -449,19 +495,24 @@ where async fn execute(&self) -> Result { let concurrency_permits = Arc::new(Semaphore::new(MULTI_STORES_CONCURRENCY)); let stores = self.pd_client.clone().all_stores().await?; - let mut handles = Vec::with_capacity(stores.len()); - for store in stores { + let stores_len = stores.len(); + let mut join_set = JoinSet::new(); + for (idx, store) in stores.into_iter().enumerate() { let mut clone = self.inner.clone(); clone.apply_store(&store); - let handle = tokio::spawn(Self::single_store_handler( - clone, - self.backoff.clone(), - concurrency_permits.clone(), - )); - handles.push(handle); + let backoff = self.backoff.clone(); + let concurrency_permits = concurrency_permits.clone(); + join_set.spawn(async move { + ( + idx, + Self::single_store_handler(clone, backoff, concurrency_permits).await, + ) + }); } - let results = try_join_all(handles).await?; - Ok(results.into_iter().collect::>()) + + let results = + collect_join_set_results(join_set, stores_len, "single_store_handler").await?; + Ok(results) } } @@ -921,6 +972,8 @@ impl HasRegionError for ResponseWithShard