Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 107 additions & 28 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ type MaybeBatch = Option<Result<RepartitionBatch>>;
/// For one output partition: a sender per input partition task (the alias name
/// encodes this mapping; each input task holds its own clone).
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// For one output partition: the matching receiver per input partition task.
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

/// Size threshold (in bytes) at which an input task flushes a
/// per-output-partition local buffer downstream.
///
/// Larger values amortize the memory-reservation and channel-send overhead
/// over more data, at the cost of added latency before a downstream receiver
/// sees new rows. 256 KiB sits comfortably above the typical per-partition
/// sub-batch size (a ~512 KiB input batch hash-split N ways) while keeping
/// latency modest.
const FLUSH_THRESHOLD_BYTES: usize = 256 * 1024;

/// Per-output-partition accumulation buffer used by [`RepartitionExec::pull_from_input`].
///
/// Collects the sub-batches destined for a single output partition together
/// with their combined in-memory footprint, so that they can be flushed (and
/// their memory reserved) in one operation instead of per batch.
#[derive(Default)]
struct PartitionBuffer {
    // Batches accumulated since the last flush, in arrival order.
    batches: Vec<RecordBatch>,
    // Running sum of `get_array_memory_size()` over `batches`; reset to 0 on flush.
    size: usize,
}

/// Output channel with its associated memory reservation and spill writer
struct OutputChannel {
sender: DistributionSender<MaybeBatch>,
Expand Down Expand Up @@ -1345,6 +1360,14 @@ impl RepartitionExec {
/// output partitions based on the desired partitioning
///
/// `output_channels` holds the output sending channels for each output partition
///
/// Batches are accumulated in per-partition local buffers and flushed only
/// when a buffer reaches [`FLUSH_THRESHOLD_BYTES`]. This amortizes the
/// memory reservation lock across many small batches: each flush performs
/// a single `try_grow` call instead of one per batch. When hash-partitioning
/// a batch into N output partitions, each output gets ~`input_size / N`
/// bytes, which can be much smaller than the typical batch size, so the
/// saving is significant.
async fn pull_from_input(
mut stream: SendableRecordBatchStream,
mut output_channels: HashMap<usize, OutputChannel>,
Expand Down Expand Up @@ -1374,8 +1397,14 @@ impl RepartitionExec {
}
};

// Per-output-partition local accumulation buffers. Indexed by partition id.
let num_output_partitions = partitioner.num_partitions();
let mut local_buffers: Vec<PartitionBuffer> = (0..num_output_partitions)
.map(|_| PartitionBuffer::default())
.collect();

// While there are still outputs to send to, keep pulling inputs
let mut batches_until_yield = partitioner.num_partitions();
let mut batches_until_yield = num_output_partitions;
while !output_channels.is_empty() {
// fetch the next batch
let timer = metrics.fetch_time.timer();
Expand All @@ -1397,34 +1426,23 @@ impl RepartitionExec {
let (partition, batch) = res?;
let size = batch.get_array_memory_size();

let timer = metrics.send_time[partition].timer();
// if there is still a receiver, send to it
if let Some(channel) = output_channels.get_mut(&partition) {
let (batch_to_send, is_memory_batch) =
match channel.reservation.lock().try_grow(size) {
Ok(_) => {
// Memory available - send in-memory batch
(RepartitionBatch::Memory(batch), true)
}
Err(_) => {
// We're memory limited - spill to SpillPool
// SpillPool handles file handle reuse and rotation
channel.spill_writer.push_batch(&batch)?;
// Send marker indicating batch was spilled
(RepartitionBatch::Spilled, false)
}
};
let buffer = &mut local_buffers[partition];
buffer.batches.push(batch);
buffer.size += size;

if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
// If the other end has hung up, it was an early shutdown (e.g. LIMIT)
// Only shrink memory if it was a memory batch
if is_memory_batch {
channel.reservation.lock().shrink(size);
}
output_channels.remove(&partition);
}
// Flush when the local buffer for this partition has accumulated
// enough data to amortize the reservation + channel overhead.
if buffer.size >= FLUSH_THRESHOLD_BYTES
&& output_channels.contains_key(&partition)
{
Self::flush_partition(
partition,
&mut local_buffers[partition],
&mut output_channels,
&metrics,
)
.await?;
}
timer.done();
}

// If the input stream is endless, we may spin forever and
Expand All @@ -1445,17 +1463,78 @@ impl RepartitionExec {
// in that case anyways
if batches_until_yield == 0 {
tokio::task::yield_now().await;
batches_until_yield = partitioner.num_partitions();
batches_until_yield = num_output_partitions;
} else {
batches_until_yield -= 1;
}
}

// Flush any remaining batches in local buffers for partitions whose
// receivers are still alive.
for (partition, buffer) in local_buffers.iter_mut().enumerate() {
if buffer.batches.is_empty() || !output_channels.contains_key(&partition) {
continue;
}
Self::flush_partition(partition, buffer, &mut output_channels, &metrics)
.await?;
}

// Spill writers will auto-finalize when dropped
// No need for explicit flush
Ok(())
}

/// Flush the accumulated batches for `partition` to its output channel.
///
/// Performs a single `try_grow` call for the buffer's total size; if that
/// fails, all accumulated batches are spilled together via the channel's
/// spill writer. The individual batches are then sent through the channel
/// as-is (no concatenation).
///
/// Always drains `buffer` (even when the partition's channel is already
/// gone); returns `Ok(())` unless spilling a batch fails.
async fn flush_partition(
    partition: usize,
    buffer: &mut PartitionBuffer,
    output_channels: &mut HashMap<usize, OutputChannel>,
    metrics: &RepartitionMetrics,
) -> Result<()> {
    // Drain the buffer up front so it is empty regardless of which exit
    // path below is taken.
    let batches = std::mem::take(&mut buffer.batches);
    let total_size = buffer.size;
    buffer.size = 0;

    // Receiver already hung up earlier: the drained batches are simply
    // dropped (nothing was reserved for them yet).
    let Some(channel) = output_channels.get_mut(&partition) else {
        return Ok(());
    };

    let timer = metrics.send_time[partition].timer();

    // Single reservation call for the entire flush. On failure, every batch
    // of this flush goes to the spill pool instead of memory.
    let in_memory = channel.reservation.lock().try_grow(total_size).is_ok();

    for batch in batches {
        let batch_to_send = if in_memory {
            RepartitionBatch::Memory(batch)
        } else {
            // Write the batch to the spill pool and send only a marker;
            // the receiver reads the data back from the pool.
            channel.spill_writer.push_batch(&batch)?;
            RepartitionBatch::Spilled
        };

        if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
            // Receiver hung up (e.g. early LIMIT). Release the full reserved
            // amount: batches still sitting in the channel buffer are dropped
            // by the receiver without going through the shrink path, and any
            // unsent local batches are dropped here.
            //
            // NOTE(review): if the receiver consumed some earlier batches of
            // this same flush *before* hanging up, and that consumption path
            // shrinks the reservation per batch, shrinking the full
            // `total_size` here would over-release — verify against the
            // receiver-side accounting.
            if in_memory {
                channel.reservation.lock().shrink(total_size);
            }
            output_channels.remove(&partition);
            timer.done();
            return Ok(());
        }
    }
    timer.done();
    Ok(())
}

/// Waits for `input_task` which is consuming one of the inputs to
/// complete. Upon each successful completion, sends a `None` to
/// each of the output tx channels to signal one of the inputs is
Expand Down
Loading