Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_interface/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ pub(crate) fn run_in_thread_pool_with_globals<
// locals to it. The new thread runs the deadlock handler.

let current_gcx2 = current_gcx2.clone();
let registry = rustc_thread_pool::Registry::current();
let registry = rustc_thread_pool::Registry::current_cloned();
let session_globals = rustc_span::with_session_globals(|session_globals| {
session_globals as *const SessionGlobals as usize
});
Expand Down
11 changes: 6 additions & 5 deletions compiler/rustc_middle/src/query/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ impl<'tcx> QueryLatch<'tcx> {
fn set(&self) {
let mut waiters_guard = self.waiters.lock();
let waiters = waiters_guard.take().unwrap(); // mark the latch as complete
let registry = rustc_thread_pool::Registry::current();
for waiter in waiters {
rustc_thread_pool::mark_unblocked(&registry);
waiter.condvar.notify_one();
}
rustc_thread_pool::Registry::current(|registry| {
for waiter in waiters {
rustc_thread_pool::mark_unblocked(registry);
waiter.condvar.notify_one();
}
});
}

/// Removes a single waiter from the list of waiters.
Expand Down
86 changes: 45 additions & 41 deletions compiler/rustc_thread_pool/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::{fmt, ptr};

use crate::job::{ArcJob, StackJob};
use crate::latch::{CountLatch, LatchRef};
Expand All @@ -24,7 +24,7 @@ where
R: Send,
{
// We assert that current registry has not terminated.
unsafe { broadcast_in(op, &Registry::current()) }
Registry::current(|registry| unsafe { broadcast_in(op, registry) })
}

/// Spawns an asynchronous task on every thread in this thread-pool. This task
Expand All @@ -40,7 +40,7 @@ where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
// We assert that current registry has not terminated.
unsafe { spawn_broadcast_in(op, &Registry::current()) }
Registry::current(|registry| unsafe { spawn_broadcast_in(op, registry) })
}

/// Provides context to a closure called by `broadcast`.
Expand All @@ -53,9 +53,9 @@ pub struct BroadcastContext<'a> {

impl<'a> BroadcastContext<'a> {
pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
let worker_thread = WorkerThread::current();
assert!(!worker_thread.is_null());
f(BroadcastContext { worker: unsafe { &*worker_thread }, _marker: PhantomData })
WorkerThread::current(|worker_thread| {
f(BroadcastContext { worker: worker_thread.as_ref().unwrap(), _marker: PhantomData })
})
}

/// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
Expand Down Expand Up @@ -98,41 +98,45 @@ where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
let current_thread = WorkerThread::current();
let current_thread_addr = current_thread.expose_provenance();
let started = &AtomicBool::new(false);
let f = move |injected: bool| {
debug_assert!(injected);

// Mark as started if we are the thread that initiated that broadcast.
if current_thread_addr == WorkerThread::current().expose_provenance() {
started.store(true, Ordering::Relaxed);
}

BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
let current_thread = unsafe { current_thread.as_ref() };
let tlv = crate::tlv::get();
let latch = CountLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> =
(0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect();
let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() });

registry.inject_broadcast(job_refs);

let current_thread_job_id = current_thread
.and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
.map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id());

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(
current_thread,
|| started.load(Ordering::Relaxed),
|job| Some(job.id()) == current_thread_job_id,
);
jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
WorkerThread::current(|current_thread| {
let current_thread = current_thread.as_ref();
let current_thread_addr = current_thread.map_or(ptr::null(), |p| p).addr();
let started = &AtomicBool::new(false);
let f = move |injected: bool| {
debug_assert!(injected);

WorkerThread::current(|worker_thread| {
let worker_thread_addr = worker_thread.as_ref().map_or(ptr::null(), |p| p).addr();
if current_thread_addr == worker_thread_addr {
// Mark as started if we are the thread that initiated that broadcast.
started.store(true, Ordering::Relaxed);
}
});

BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
let tlv = crate::tlv::get();
let latch = CountLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> =
(0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect();
let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() });

registry.inject_broadcast(job_refs);

let current_thread_job_id = current_thread
.filter(|worker| registry.id() == worker.registry.id())
.map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id());

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(
current_thread,
|| started.load(Ordering::Relaxed),
|job| Some(job.id()) == current_thread_job_id,
);
jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
})
}

/// Execute `op` on every thread in the pool. It will be executed on each
Expand Down
Loading
Loading