Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,23 @@ serde_json = "1"
take_mut = "0.2.2"
thiserror = "1"
tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] }
# Keep these compatible with rust-toolchain 1.84.1.
indexmap = "=2.13.1"
icu_collections = "<2.2"
icu_locale_core = "<2.2"
icu_normalizer = "<2.2"
icu_normalizer_data = "<2.2"
icu_properties = "<2.2"
icu_properties_data = "<2.2"
icu_provider = "<2.2"

tonic = { version = "0.10", features = ["tls", "gzip"] }

[dev-dependencies]
clap = "2"
env_logger = "0.10"
fail = { version = "0.4", features = ["failpoints"] }
proptest = "1"
proptest = "<1.11"
proptest-derive = "0.5.1"
reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] }
rstest = "0.18.2"
Expand Down
3 changes: 2 additions & 1 deletion proto-build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ edition = "2021"
[dependencies]
glob = "0.3"
tonic-build = { version = "0.10", features = ["cleanup-markdown"] }
# Keep this compatible with rust-toolchain 1.84.1.
# Keep these compatible with rust-toolchain 1.84.1.
indexmap = "=2.13.1"
tempfile = "=3.14.0"
120 changes: 95 additions & 25 deletions src/request/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use std::sync::Arc;

use async_recursion::async_recursion;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::prelude::*;
use log::debug;
use log::error;
use log::info;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tokio::time::sleep;

use crate::backoff::Backoff;
Expand Down Expand Up @@ -92,6 +93,35 @@ pub(crate) fn is_grpc_error(e: &Error) -> bool {
matches!(e, Error::GrpcAPI(_) | Error::Grpc(_))
}

/// Awaits every task in `join_set` and returns their outputs ordered by the
/// `usize` index each task reports — i.e. spawn order, not completion order.
///
/// If joining any task fails (panic or cancellation), the remaining tasks are
/// shut down and the join error is propagated to the caller.
///
/// Callers must spawn exactly `task_count` tasks, each tagged with a distinct
/// index in `0..task_count`; otherwise the final unwrap below panics.
async fn collect_join_set_results<T>(
    mut join_set: JoinSet<(usize, T)>,
    task_count: usize,
    handler_name: &str,
) -> Result<Vec<T>>
where
    T: Send + 'static,
{
    // One slot per spawned task; filled in as tasks complete in any order.
    let mut slots: Vec<Option<T>> = std::iter::repeat_with(|| None).take(task_count).collect();
    while let Some(joined) = join_set.join_next().await {
        let (idx, val) = match joined {
            Ok(pair) => pair,
            Err(e) => {
                error!(
                    "{}: failed to join task ({} tasks): {}",
                    handler_name, task_count, e
                );
                // Abort whatever is still running before surfacing the error.
                join_set.shutdown().await;
                return Err(Error::JoinError(e));
            }
        };
        slots[idx] = Some(val);
    }

    Ok(slots
        .into_iter()
        .map(|slot| slot.expect("all spawned tasks should return a result"))
        .collect())
}

pub struct RetryableMultiRegion<P: Plan, PdC: PdClient> {
pub(super) inner: P,
pub pd_client: Arc<PdC>,
Expand All @@ -117,23 +147,39 @@ where
preserve_region_results: bool,
) -> Result<<Self as Plan>::Result> {
let shards = current_plan.shards(&pd_client).collect::<Vec<_>>().await;
debug!("single_plan_handler, shards: {}", shards.len());
let mut handles = Vec::with_capacity(shards.len());
for shard in shards {
let (shard, region) = shard?;
let shards_len = shards.len();
debug!("single_plan_handler, shards: {}", shards_len);
let mut join_set = JoinSet::new();
for (idx, shard) in shards.into_iter().enumerate() {
let (shard, region) = match shard {
Ok(shard) => shard,
Err(e) => {
join_set.shutdown().await;
return Err(e);
}
};
let clone = current_plan.clone_then_apply_shard(shard);
let handle = tokio::spawn(Self::single_shard_handler(
pd_client.clone(),
clone,
region,
backoff.clone(),
permits.clone(),
preserve_region_results,
));
handles.push(handle);
let pd_client = pd_client.clone();
let backoff = backoff.clone();
let permits = permits.clone();
join_set.spawn(async move {
(
idx,
Self::single_shard_handler(
pd_client,
clone,
region,
backoff,
permits,
preserve_region_results,
)
.await,
)
});
}

let results = try_join_all(handles).await?;
let results = collect_join_set_results(join_set, shards_len, "single_plan_handler").await?;

if preserve_region_results {
Ok(results
.into_iter()
Expand Down Expand Up @@ -449,19 +495,24 @@ where
async fn execute(&self) -> Result<Self::Result> {
let concurrency_permits = Arc::new(Semaphore::new(MULTI_STORES_CONCURRENCY));
let stores = self.pd_client.clone().all_stores().await?;
let mut handles = Vec::with_capacity(stores.len());
for store in stores {
let stores_len = stores.len();
let mut join_set = JoinSet::new();
for (idx, store) in stores.into_iter().enumerate() {
let mut clone = self.inner.clone();
clone.apply_store(&store);
let handle = tokio::spawn(Self::single_store_handler(
clone,
self.backoff.clone(),
concurrency_permits.clone(),
));
handles.push(handle);
let backoff = self.backoff.clone();
let concurrency_permits = concurrency_permits.clone();
join_set.spawn(async move {
(
idx,
Self::single_store_handler(clone, backoff, concurrency_permits).await,
)
});
}
let results = try_join_all(handles).await?;
Ok(results.into_iter().collect::<Vec<_>>())

let results =
collect_join_set_results(join_set, stores_len, "single_store_handler").await?;
Ok(results)
}
}

Expand Down Expand Up @@ -921,6 +972,8 @@ impl<Resp: HasRegionError, Shard> HasRegionError for ResponseWithShard<Resp, Sha

#[cfg(test)]
mod test {
use std::time::Duration;

use futures::stream::BoxStream;
use futures::stream::{self};

Expand Down Expand Up @@ -973,4 +1026,21 @@ mod test {
};
assert!(plan.execute().await.is_err())
}

#[tokio::test]
async fn test_join_set_results_keep_spawn_order() {
    // Deliberately make tasks finish out of spawn order (task 0 sleeps the
    // longest): the collected results must still follow spawn order.
    let mut join_set = JoinSet::new();
    for (idx, delay_ms) in [30u64, 10, 20].into_iter().enumerate() {
        join_set.spawn(async move {
            sleep(Duration::from_millis(delay_ms)).await;
            (idx, idx)
        });
    }

    let collected = collect_join_set_results(join_set, 3, "test_handler")
        .await
        .unwrap();

    assert_eq!(collected, vec![0, 1, 2]);
}
}
Loading