diff --git a/compiler/rustc_interface/src/util.rs b/compiler/rustc_interface/src/util.rs index 24b23cc4199e9..64fb275882ed3 100644 --- a/compiler/rustc_interface/src/util.rs +++ b/compiler/rustc_interface/src/util.rs @@ -222,7 +222,7 @@ pub(crate) fn run_in_thread_pool_with_globals< // locals to it. The new thread runs the deadlock handler. let current_gcx2 = current_gcx2.clone(); - let registry = rustc_thread_pool::Registry::current(); + let registry = rustc_thread_pool::Registry::current_cloned(); let session_globals = rustc_span::with_session_globals(|session_globals| { session_globals as *const SessionGlobals as usize }); diff --git a/compiler/rustc_middle/src/query/job.rs b/compiler/rustc_middle/src/query/job.rs index 8c78bf24287e0..5a0683334a74d 100644 --- a/compiler/rustc_middle/src/query/job.rs +++ b/compiler/rustc_middle/src/query/job.rs @@ -120,11 +120,12 @@ impl<'tcx> QueryLatch<'tcx> { fn set(&self) { let mut waiters_guard = self.waiters.lock(); let waiters = waiters_guard.take().unwrap(); // mark the latch as complete - let registry = rustc_thread_pool::Registry::current(); - for waiter in waiters { - rustc_thread_pool::mark_unblocked(®istry); - waiter.condvar.notify_one(); - } + rustc_thread_pool::Registry::current(|registry| { + for waiter in waiters { + rustc_thread_pool::mark_unblocked(registry); + waiter.condvar.notify_one(); + } + }); } /// Removes a single waiter from the list of waiters. diff --git a/compiler/rustc_thread_pool/src/broadcast/mod.rs b/compiler/rustc_thread_pool/src/broadcast/mod.rs index 1707ebb59883c..ac6245f96a6ae 100644 --- a/compiler/rustc_thread_pool/src/broadcast/mod.rs +++ b/compiler/rustc_thread_pool/src/broadcast/mod.rs @@ -1,7 +1,7 @@ -use std::fmt; use std::marker::PhantomData; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; +use std::{fmt, ptr}; use crate::job::{ArcJob, StackJob}; use crate::latch::{CountLatch, LatchRef}; @@ -24,7 +24,7 @@ where R: Send, { // We assert that current registry has not terminated. - unsafe { broadcast_in(op, &Registry::current()) } + Registry::current(|registry| unsafe { broadcast_in(op, registry) }) } /// Spawns an asynchronous task on every thread in this thread-pool. This task @@ -40,7 +40,7 @@ where OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, { // We assert that current registry has not terminated. - unsafe { spawn_broadcast_in(op, &Registry::current()) } + Registry::current(|registry| unsafe { spawn_broadcast_in(op, registry) }) } /// Provides context to a closure called by `broadcast`. @@ -53,9 +53,9 @@ pub struct BroadcastContext<'a> { impl<'a> BroadcastContext<'a> { pub(super) fn with(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { - let worker_thread = WorkerThread::current(); - assert!(!worker_thread.is_null()); - f(BroadcastContext { worker: unsafe { &*worker_thread }, _marker: PhantomData }) + WorkerThread::current(|worker_thread| { + f(BroadcastContext { worker: worker_thread.as_ref().unwrap(), _marker: PhantomData }) + }) } /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). @@ -98,41 +98,45 @@ where OP: Fn(BroadcastContext<'_>) -> R + Sync, R: Send, { - let current_thread = WorkerThread::current(); - let current_thread_addr = current_thread.expose_provenance(); - let started = &AtomicBool::new(false); - let f = move |injected: bool| { - debug_assert!(injected); - - // Mark as started if we are the thread that initiated that broadcast. - if current_thread_addr == WorkerThread::current().expose_provenance() { - started.store(true, Ordering::Relaxed); - } - - BroadcastContext::with(&op) - }; - - let n_threads = registry.num_threads(); - let current_thread = unsafe { current_thread.as_ref() }; - let tlv = crate::tlv::get(); - let latch = CountLatch::with_count(n_threads, current_thread); - let jobs: Vec<_> = - (0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect(); - let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() }); - - registry.inject_broadcast(job_refs); - - let current_thread_job_id = current_thread - .and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker)) - .map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id()); - - // Wait for all jobs to complete, then collect the results, maybe propagating a panic. - latch.wait( - current_thread, - || started.load(Ordering::Relaxed), - |job| Some(job.id()) == current_thread_job_id, - ); - jobs.into_iter().map(|job| unsafe { job.into_result() }).collect() + WorkerThread::current(|current_thread| { + let current_thread = current_thread.as_ref(); + let current_thread_addr = current_thread.map_or(ptr::null(), |p| p).addr(); + let started = &AtomicBool::new(false); + let f = move |injected: bool| { + debug_assert!(injected); + + WorkerThread::current(|worker_thread| { + let worker_thread_addr = worker_thread.as_ref().map_or(ptr::null(), |p| p).addr(); + if current_thread_addr == worker_thread_addr { + // Mark as started if we are the thread that initiated that broadcast. + started.store(true, Ordering::Relaxed); + } + }); + + BroadcastContext::with(&op) + }; + + let n_threads = registry.num_threads(); + let tlv = crate::tlv::get(); + let latch = CountLatch::with_count(n_threads, current_thread); + let jobs: Vec<_> = + (0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect(); + let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() }); + + registry.inject_broadcast(job_refs); + + let current_thread_job_id = current_thread + .filter(|worker| registry.id() == worker.registry.id()) + .map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id()); + + // Wait for all jobs to complete, then collect the results, maybe propagating a panic. + latch.wait( + current_thread, + || started.load(Ordering::Relaxed), + |job| Some(job.id()) == current_thread_job_id, + ); + jobs.into_iter().map(|job| unsafe { job.into_result() }).collect() + }) } /// Execute `op` on every thread in the pool. It will be executed on each diff --git a/compiler/rustc_thread_pool/src/registry.rs b/compiler/rustc_thread_pool/src/registry.rs index 92bb8961e7dfc..20efb92a362da 100644 --- a/compiler/rustc_thread_pool/src/registry.rs +++ b/compiler/rustc_thread_pool/src/registry.rs @@ -1,6 +1,7 @@ -use std::cell::Cell; +use std::cell::{Cell, UnsafeCell}; use std::collections::hash_map::DefaultHasher; use std::hash::Hasher; +use std::mem::ManuallyDrop; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, Once}; use std::{fmt, io, mem, ptr, thread}; @@ -208,41 +209,36 @@ where } fn default_global_registry() -> Result, ThreadPoolBuildError> { - let result = Registry::new(ThreadPoolBuilder::new()); + Registry::new(ThreadPoolBuilder::new()).or_else(|e| { + // If we're running in an environment that doesn't support threads at all, we can fall back to + // using the current thread alone. This is crude, and probably won't work for non-blocking + // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine. + // + // Notably, this allows current WebAssembly targets to work even though their threading support + // is stubbed out, and we won't have to change anything if they do add real threading. + if e.is_unsupported() && WorkerThread::current(Option::is_none) { + let builder = ThreadPoolBuilder::new().num_threads(1).spawn_handler(|thread| { + // Rather than starting a new thread, we're just taking over the current thread + // *without* running the main loop, so we can still return from here. + // The WorkerThread is leaked, but we never shutdown the global pool anyway. + WorkerThread::set_current(WorkerThread::from(thread)); + WorkerThread::current(|worker_thread| { + let worker_thread = worker_thread.as_ref().unwrap(); + let latch = &worker_thread.registry.thread_infos[worker_thread.index].primed; + // let registry know we are ready to do work + unsafe { Latch::set(latch) } + }); - // If we're running in an environment that doesn't support threads at all, we can fall back to - // using the current thread alone. This is crude, and probably won't work for non-blocking - // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine. - // - // Notably, this allows current WebAssembly targets to work even though their threading support - // is stubbed out, and we won't have to change anything if they do add real threading. - let unsupported = matches!(&result, Err(e) if e.is_unsupported()); - if unsupported && WorkerThread::current().is_null() { - let builder = ThreadPoolBuilder::new().num_threads(1).spawn_handler(|thread| { - // Rather than starting a new thread, we're just taking over the current thread - // *without* running the main loop, so we can still return from here. - // The WorkerThread is leaked, but we never shutdown the global pool anyway. - let worker_thread = Box::leak(Box::new(WorkerThread::from(thread))); - let registry = &*worker_thread.registry; - let index = worker_thread.index; + Ok(()) + }); - unsafe { - WorkerThread::set_current(worker_thread); - - // let registry know we are ready to do work - Latch::set(®istry.thread_infos[index].primed); + if let Ok(fallback_result) = Registry::new(builder) { + return Ok(fallback_result); } - - Ok(()) - }); - - let fallback_result = Registry::new(builder); - if fallback_result.is_ok() { - return fallback_result; } - } - result + Err(e) + }) } struct Terminator<'a>(&'a Arc); @@ -319,47 +315,47 @@ impl Registry { Ok(registry) } - pub fn current() -> Arc { - unsafe { - let worker_thread = WorkerThread::current(); - let registry = if worker_thread.is_null() { - global_registry() - } else { - &(*worker_thread).registry - }; - Arc::clone(registry) - } + pub fn current(f: F) -> R + where + F: FnOnce(&Arc) -> R, + { + WorkerThread::current(|worker_thread| { + f(match worker_thread { + Some(worker_thread) => &worker_thread.registry, + None => global_registry(), + }) + }) + } + + pub fn current_cloned() -> Arc { + Registry::current(Arc::clone) } /// Returns the number of threads in the current registry. This /// is better than `Registry::current().num_threads()` because it /// avoids incrementing the `Arc`. pub(super) fn current_num_threads() -> usize { - unsafe { - let worker_thread = WorkerThread::current(); - if worker_thread.is_null() { - global_registry().num_threads() - } else { - (*worker_thread).registry.num_threads() - } - } + Registry::current(|registry| registry.num_threads()) } /// Returns the current `WorkerThread` if it's part of this `Registry`. - pub(super) fn current_thread(&self) -> Option<&WorkerThread> { - unsafe { - let worker = WorkerThread::current().as_ref()?; - if worker.registry().id() == self.id() { Some(worker) } else { None } - } + pub(super) fn current_thread_with(&self, f: F) -> R + where + F: FnOnce(Option<&WorkerThread>) -> R, + { + WorkerThread::current(|worker| { + f(worker.as_ref().filter(|worker| worker.registry.id() == self.id())) + }) } /// Returns an opaque identifier for this registry. pub(super) fn id(&self) -> RegistryId { // We can rely on `self` not to change since we only ever create // registries that are boxed up in an `Arc` (see `new()` above). - RegistryId { addr: self as *const Self as usize } + RegistryId { addr: (self as *const Self).addr() } } + #[inline] pub(super) fn num_threads(&self) -> usize { self.thread_infos.len() } @@ -417,14 +413,13 @@ impl Registry { /// worker thread for the registry, this will push onto the /// deque. Else, it will inject from the outside (which is slower). pub(super) fn inject_or_push(&self, job_ref: JobRef) { - let worker_thread = WorkerThread::current(); - unsafe { - if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() { - (*worker_thread).push(job_ref); + self.current_thread_with(|worker_thread| { + if let Some(worker_thread) = worker_thread { + unsafe { worker_thread.push(job_ref) }; } else { - self.inject(job_ref); + self.inject(job_ref) } - } + }) } /// Push a job into the "external jobs" queue; it will be taken by @@ -503,19 +498,19 @@ impl Registry { OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send, { - unsafe { - let worker_thread = WorkerThread::current(); - if worker_thread.is_null() { - self.in_worker_cold(op) - } else if (*worker_thread).registry().id() != self.id() { - self.in_worker_cross(&*worker_thread, op) - } else { - // Perfectly valid to give them a `&T`: this is the - // current thread, so we know the data structure won't be - // invalidated until we return. - op(&*worker_thread, false) + WorkerThread::current(|worker_thread| { + let Some(worker_thread) = worker_thread else { + return unsafe { self.in_worker_cold(op) }; + }; + if worker_thread.registry().id() != self.id() { + return unsafe { self.in_worker_cross(worker_thread, op) }; } - } + + // Perfectly valid to give them a `&T`: this is the + // current thread, so we know the data structure won't be + // invalidated until we return. + op(&*worker_thread, false) + }) } #[cold] @@ -528,13 +523,12 @@ impl Registry { LOCK_LATCH.with(|l| { // This thread isn't a member of *any* thread pool, so just block. - debug_assert!(WorkerThread::current().is_null()); + debug_assert!(WorkerThread::current(Option::is_none)); let job = StackJob::new( Tlv::null(), |injected| { - let worker_thread = WorkerThread::current(); - assert!(injected && !worker_thread.is_null()); - op(unsafe { &*worker_thread }, true) + assert!(injected); + WorkerThread::current(|worker_thread| op(worker_thread.as_ref().unwrap(), true)) }, LatchRef::new(l), ); @@ -560,9 +554,8 @@ impl Registry { let job = StackJob::new( Tlv::null(), |injected| { - let worker_thread = WorkerThread::current(); - assert!(injected && !worker_thread.is_null()); - op(unsafe { &*worker_thread }, true) + assert!(injected); + WorkerThread::current(|worker_thread| op(worker_thread.as_ref().unwrap(), true)) }, latch, ); @@ -618,12 +611,10 @@ impl Registry { /// if no other worker thread is active #[inline] pub fn mark_blocked() { - let worker_thread = WorkerThread::current(); - assert!(!worker_thread.is_null()); - unsafe { - let registry = &(*worker_thread).registry; + WorkerThread::current(|worker_thread| { + let registry = &*worker_thread.as_ref().unwrap().registry; registry.sleep.mark_blocked(®istry.deadlock_handler) - } + }); } /// Mark a previously blocked Rayon worker thread as unblocked @@ -695,7 +686,8 @@ pub(super) struct WorkerThread { // worker is fully unwound. Using an unsafe pointer avoids the need // for a RefCell etc. thread_local! { - static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) }; + static WORKER_THREAD_STATE: ManuallyDrop>> = + const { ManuallyDrop::new(UnsafeCell::new(None)) }; } impl From for WorkerThread { @@ -713,30 +705,53 @@ impl From for WorkerThread { impl Drop for WorkerThread { fn drop(&mut self) { - // Undo `set_current` - WORKER_THREAD_STATE.with(|t| { - assert!(t.get().eq(&(self as *const _))); - t.set(ptr::null()); - }); + assert!(WorkerThread::current(Option::is_none)); } } impl WorkerThread { - /// Gets the `WorkerThread` index for the current thread; returns - /// NULL if this is not a worker thread. This pointer is valid + /// Gets the `WorkerThread` index for the current thread; passes + /// None if this is not a worker thread. This pointer is valid /// anywhere on the current thread. #[inline] - pub(super) fn current() -> *const WorkerThread { - WORKER_THREAD_STATE.with(Cell::get) + pub(super) fn current(f: F) -> R + where + F: FnOnce(&Option) -> R, + { + WORKER_THREAD_STATE.with(|t| unsafe { f(&*t.get()) }) } /// Sets `self` as the worker thread index for the current thread. /// This is done during worker thread startup. - unsafe fn set_current(thread: *const WorkerThread) { - WORKER_THREAD_STATE.with(|t| { - assert!(t.get().is_null()); - t.set(thread); - }); + fn set_current(thread: WorkerThread) { + WORKER_THREAD_STATE.with(|t| unsafe { + let t = &mut *t.get(); + assert!(t.is_none()); + *t = Some(thread) + }) + } + + /// Sets `self` as the worker thread index for the current thread. + /// This is done during worker thread startup. + fn with_current(thread: WorkerThread, f: F) -> R + where + F: FnOnce(&WorkerThread) -> R, + { + WORKER_THREAD_STATE.with(|t| unsafe { + assert!((&*t.get()).is_none()); + t.get().write(Some(thread)); + + struct Guard(*mut Option); + + impl Drop for Guard { + fn drop(&mut self) { + unsafe { self.0.replace(None) }; + } + } + + let _g = Guard(t.get()); + f((&*t.get()).as_ref().unwrap_unchecked()) + }) } /// Returns the registry that owns this worker thread. @@ -932,7 +947,8 @@ impl WorkerThread { } unsafe fn wait_until_out_of_work(&self) { - debug_assert_eq!(self as *const _, WorkerThread::current()); + debug_assert!(WorkerThread::current(|curr| ptr::eq(self, curr.as_ref().unwrap()))); + let registry = &*self.registry; let index = self.index; @@ -1020,36 +1036,36 @@ impl WorkerThread { } unsafe fn main_loop(thread: ThreadBuilder) { - let worker_thread = &WorkerThread::from(thread); - unsafe { WorkerThread::set_current(worker_thread) }; - let registry = &*worker_thread.registry; - let index = worker_thread.index; - - // let registry know we are ready to do work - unsafe { Latch::set(®istry.thread_infos[index].primed) }; - - // Worker threads should not panic. If they do, just abort, as the - // internal state of the threadpool is corrupted. Note that if - // **user code** panics, we should catch that and redirect. - let abort_guard = unwind::AbortIfPanic; - - // Inform a user callback that we started a thread. - if let Some(ref handler) = registry.start_handler { - registry.catch_unwind(|| handler(index)); - } + WorkerThread::with_current(WorkerThread::from(thread), |worker_thread| { + let registry = &*worker_thread.registry; + let index = worker_thread.index; + + // let registry know we are ready to do work + unsafe { Latch::set(®istry.thread_infos[index].primed) }; - unsafe { worker_thread.wait_until_out_of_work() }; + // Worker threads should not panic. If they do, just abort, as the + // internal state of the threadpool is corrupted. Note that if + // **user code** panics, we should catch that and redirect. + let abort_guard = unwind::AbortIfPanic; + + // Inform a user callback that we started a thread. + if let Some(ref handler) = registry.start_handler { + registry.catch_unwind(|| handler(index)); + } - // Normal termination, do not abort. - mem::forget(abort_guard); + unsafe { worker_thread.wait_until_out_of_work() }; - // Inform a user callback that we exited a thread. - if let Some(ref handler) = registry.exit_handler { - registry.catch_unwind(|| handler(index)); - // We're already exiting the thread, there's nothing else to do. - } + // Normal termination, do not abort. + mem::forget(abort_guard); - registry.release_thread(); + // Inform a user callback that we exited a thread. + if let Some(ref handler) = registry.exit_handler { + registry.catch_unwind(|| handler(index)); + // We're already exiting the thread, there's nothing else to do. + } + + registry.release_thread(); + }); } /// If already in a worker-thread, just execute `op`. Otherwise, @@ -1062,17 +1078,16 @@ where OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send, { - unsafe { - let owner_thread = WorkerThread::current(); - if !owner_thread.is_null() { + WorkerThread::current(|owner_thread| { + if let Some(owner_thread) = owner_thread { // Perfectly valid to give them a `&T`: this is the // current thread, so we know the data structure won't be // invalidated until we return. - op(&*owner_thread, false) + op(owner_thread, false) } else { global_registry().in_worker(op) } - } + }) } /// [xorshift*] is a fast pseudorandom number generator which will diff --git a/compiler/rustc_thread_pool/src/scope/mod.rs b/compiler/rustc_thread_pool/src/scope/mod.rs index 677009a9bc3da..df31f6a64f482 100644 --- a/compiler/rustc_thread_pool/src/scope/mod.rs +++ b/compiler/rustc_thread_pool/src/scope/mod.rs @@ -18,7 +18,7 @@ use crate::job::{ArcJob, HeapJob, JobFifo, JobRef, JobRefId}; use crate::latch::{CountLatch, Latch}; use crate::registry::{Registry, WorkerThread, global_registry, in_worker}; use crate::tlv::{self, Tlv}; -use crate::unwind; +use crate::{current_thread_index, unwind}; #[cfg(test)] mod tests; @@ -428,9 +428,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc>, where OP: FnOnce(&Scope<'scope>) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = Scope::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + WorkerThread::current(|thread| { + let thread = thread.as_ref(); + let scope = Scope::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } /// Creates a "fork-join" scope `s` with FIFO order, and invokes the @@ -465,9 +467,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = ScopeFifo::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + WorkerThread::current(|thread| { + let thread = thread.as_ref(); + let scope = ScopeFifo::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } impl<'scope> Scope<'scope> { @@ -566,8 +570,7 @@ impl<'scope> Scope<'scope> { let scope = scope_ptr.as_ref(); let body = &body; - let current_index = WorkerThread::current().as_ref().map(|worker| worker.index()); - if current_index == scope.base.worker { + if scope.base.worker == current_thread_index() { // Mark this job as started on the scope's worker thread. scope.base.pending_jobs.lock().unwrap().remove(&id); } @@ -639,8 +642,7 @@ impl<'scope> ScopeFifo<'scope> { // SAFETY: this job will execute before the scope ends. let scope = scope_ptr.as_ref(); - let current_index = WorkerThread::current().as_ref().map(|worker| worker.index()); - if current_index == scope.base.worker { + if scope.base.worker == current_thread_index() { // Mark this job as started on the scope's worker thread. scope.base.pending_jobs.lock().unwrap().remove(&id); } diff --git a/compiler/rustc_thread_pool/src/spawn/mod.rs b/compiler/rustc_thread_pool/src/spawn/mod.rs index d403deaa1088f..6926e91f79d1e 100644 --- a/compiler/rustc_thread_pool/src/spawn/mod.rs +++ b/compiler/rustc_thread_pool/src/spawn/mod.rs @@ -64,7 +64,7 @@ where F: FnOnce() + Send + 'static, { // We assert that current registry has not terminated. - unsafe { spawn_in(func, &Registry::current()) } + Registry::current(|registry| unsafe { spawn_in(func, registry) }); } /// Spawns an asynchronous job in `registry.` @@ -134,7 +134,7 @@ where F: FnOnce() + Send + 'static, { // We assert that current registry has not terminated. - unsafe { spawn_fifo_in(func, &Registry::current()) } + Registry::current(|registry| unsafe { spawn_fifo_in(func, registry) }); } /// Spawns an asynchronous FIFO job in `registry.` @@ -154,10 +154,10 @@ where // If we're in the pool, use our thread's private fifo for this thread to execute // in a locally-FIFO order. Otherwise, just use the pool's global injector. - match registry.current_thread() { + registry.current_thread_with(|worker| match worker { Some(worker) => unsafe { worker.push_fifo(job_ref) }, None => registry.inject(job_ref), - } + }); mem::forget(abort_guard); } diff --git a/compiler/rustc_thread_pool/src/thread_pool/mod.rs b/compiler/rustc_thread_pool/src/thread_pool/mod.rs index 3294e2a77cbe6..f88c592eafd3d 100644 --- a/compiler/rustc_thread_pool/src/thread_pool/mod.rs +++ b/compiler/rustc_thread_pool/src/thread_pool/mod.rs @@ -235,8 +235,7 @@ impl ThreadPool { /// [snt]: struct.ThreadPoolBuilder.html#method.num_threads #[inline] pub fn current_thread_index(&self) -> Option { - let curr = self.registry.current_thread()?; - Some(curr.index()) + self.registry.current_thread_with(|curr| curr.map(WorkerThread::index)) } /// Returns true if the current worker thread currently has "local @@ -262,8 +261,7 @@ impl ThreadPool { /// [deque]: https://en.wikipedia.org/wiki/Double-ended_queue #[inline] pub fn current_thread_has_pending_tasks(&self) -> Option { - let curr = self.registry.current_thread()?; - Some(!curr.local_deque_is_empty()) + self.registry.current_thread_with(|curr| curr.map(|curr| !curr.local_deque_is_empty())) } /// Execute `oper_a` and `oper_b` in the thread-pool and return @@ -384,8 +382,7 @@ impl ThreadPool { /// Returns `Some(Yield::Executed)` if anything was executed, `Some(Yield::Idle)` if /// nothing was available, or `None` if the current thread is not part this pool. pub fn yield_now(&self) -> Option { - let curr = self.registry.current_thread()?; - Some(curr.yield_now()) + self.registry.current_thread_with(|curr| curr.map(WorkerThread::yield_now)) } /// Cooperatively yields execution to local Rayon work. @@ -396,8 +393,7 @@ impl ThreadPool { /// Returns `Some(Yield::Executed)` if anything was executed, `Some(Yield::Idle)` if /// nothing was available, or `None` if the current thread is not part this pool. pub fn yield_local(&self) -> Option { - let curr = self.registry.current_thread()?; - Some(curr.yield_local()) + self.registry.current_thread_with(|curr| curr.map(WorkerThread::yield_local)) } pub(crate) fn wait_until_stopped(self) { @@ -447,10 +443,7 @@ impl fmt::Debug for ThreadPool { /// [snt]: struct.ThreadPoolBuilder.html#method.num_threads #[inline] pub fn current_thread_index() -> Option { - unsafe { - let curr = WorkerThread::current().as_ref()?; - Some(curr.index()) - } + WorkerThread::current(|curr| curr.as_ref().map(WorkerThread::index)) } /// If called from a Rayon worker thread, indicates whether that @@ -461,10 +454,7 @@ pub fn current_thread_index() -> Option { /// [m]: struct.ThreadPool.html#method.current_thread_has_pending_tasks #[inline] pub fn current_thread_has_pending_tasks() -> Option { - unsafe { - let curr = WorkerThread::current().as_ref()?; - Some(!curr.local_deque_is_empty()) - } + WorkerThread::current(|curr| curr.as_ref().map(|curr| !curr.local_deque_is_empty())) } /// Cooperatively yields execution to Rayon. @@ -480,10 +470,7 @@ pub fn current_thread_has_pending_tasks() -> Option { /// Returns `Some(Yield::Executed)` if anything was executed, `Some(Yield::Idle)` if /// nothing was available, or `None` if this thread is not part of any pool at all. pub fn yield_now() -> Option { - unsafe { - let thread = WorkerThread::current().as_ref()?; - Some(thread.yield_now()) - } + WorkerThread::current(|curr| curr.as_ref().map(WorkerThread::yield_now)) } /// Cooperatively yields execution to local Rayon work. @@ -497,10 +484,7 @@ pub fn yield_now() -> Option { /// Returns `Some(Yield::Executed)` if anything was executed, `Some(Yield::Idle)` if /// nothing was available, or `None` if this thread is not part of any pool at all. pub fn yield_local() -> Option { - unsafe { - let thread = WorkerThread::current().as_ref()?; - Some(thread.yield_local()) - } + WorkerThread::current(|curr| curr.as_ref().map(WorkerThread::yield_local)) } /// Result of [`yield_now()`] or [`yield_local()`]. diff --git a/compiler/rustc_thread_pool/src/worker_local.rs b/compiler/rustc_thread_pool/src/worker_local.rs index 912001233bfea..6aac4ae85174c 100644 --- a/compiler/rustc_thread_pool/src/worker_local.rs +++ b/compiler/rustc_thread_pool/src/worker_local.rs @@ -1,6 +1,6 @@ -use std::fmt; use std::ops::Deref; use std::sync::Arc; +use std::{fmt, ptr}; use crate::registry::{Registry, WorkerThread}; @@ -26,7 +26,7 @@ impl WorkerLocal { /// value this worker local should take for each thread in the thread pool. #[inline] pub fn new T>(mut initial: F) -> WorkerLocal { - let registry = Registry::current(); + let registry = Registry::current_cloned(); WorkerLocal { locals: (0..registry.num_threads()).map(|i| CacheAligned(initial(i))).collect(), registry, @@ -40,15 +40,12 @@ impl WorkerLocal { } fn current(&self) -> &T { - unsafe { - let worker_thread = WorkerThread::current(); - if worker_thread.is_null() - || !std::ptr::eq(&*(*worker_thread).registry, &*self.registry) - { - panic!("WorkerLocal can only be used on the thread pool it was created on") + WorkerThread::current(|worker_thread| match worker_thread { + Some(worker_thread) if ptr::eq(&*worker_thread.registry, &*self.registry) => { + &self.locals[worker_thread.index].0 } - &self.locals[(*worker_thread).index].0 - } + _ => panic!("WorkerLocal can only be used on the thread pool it was created on"), + }) } }