From 0e24ad537be4d47686f3b9e3e6623664bce7cbc2 Mon Sep 17 00:00:00 2001 From: Mara Bos Date: Tue, 4 Jan 2022 14:51:39 +0100 Subject: [PATCH] Implement RFC 3151: Scoped threads. --- library/std/src/thread/mod.rs | 104 +++++++++++++++++------- library/std/src/thread/scoped.rs | 132 +++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+), 30 deletions(-) create mode 100644 library/std/src/thread/scoped.rs diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs index c799b64c05b2..0125545a3dbd 100644 --- a/library/std/src/thread/mod.rs +++ b/library/std/src/thread/mod.rs @@ -180,6 +180,12 @@ use crate::time::Duration; #[macro_use] mod local; +#[unstable(feature = "scoped_threads", issue = "none")] +mod scoped; + +#[unstable(feature = "scoped_threads", issue = "none")] +pub use scoped::{scope, Scope, ScopedJoinHandle}; + #[stable(feature = "rust1", since = "1.0.0")] pub use self::local::{AccessError, LocalKey}; @@ -446,6 +452,20 @@ impl Builder { F: FnOnce() -> T, F: Send + 'a, T: Send + 'a, + { + Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?)) + } + + unsafe fn spawn_unchecked_<'a, 'scope, F, T>( + self, + f: F, + scope_data: Option<&'scope scoped::ScopeData>, + ) -> io::Result> + where + F: FnOnce() -> T, + F: Send + 'a, + T: Send + 'a, + 'scope: 'a, { let Builder { name, stack_size } = self; @@ -456,7 +476,8 @@ impl Builder { })); let their_thread = my_thread.clone(); - let my_packet: Arc>>> = Arc::new(UnsafeCell::new(None)); + let my_packet: Arc> = + Arc::new(Packet { scope: scope_data, result: UnsafeCell::new(None) }); let their_packet = my_packet.clone(); let output_capture = crate::io::set_output_capture(None); @@ -480,10 +501,14 @@ impl Builder { // closure (it is an Arc<...>) and `my_packet` will be stored in the // same `JoinInner` as this closure meaning the mutation will be // safe (not modify it and affect a value far away). - unsafe { *their_packet.get() = Some(try_result) }; + unsafe { *their_packet.result.get() = Some(try_result) }; }; - Ok(JoinHandle(JoinInner { + if let Some(scope_data) = scope_data { + scope_data.increment_n_running_threads(); + } + + Ok(JoinInner { // SAFETY: // // `imp::Thread::new` takes a closure with a `'static` lifetime, since it's passed @@ -506,8 +531,8 @@ impl Builder { )? }, thread: my_thread, - packet: Packet(my_packet), - })) + packet: my_packet, + }) } } @@ -1239,34 +1264,53 @@ impl fmt::Debug for Thread { #[stable(feature = "rust1", since = "1.0.0")] pub type Result = crate::result::Result>; -// This packet is used to communicate the return value between the spawned thread -// and the rest of the program. Memory is shared through the `Arc` within and there's -// no need for a mutex here because synchronization happens with `join()` (the -// caller will never read this packet until the thread has exited). +// This packet is used to communicate the return value between the spawned +// thread and the rest of the program. It is shared through an `Arc` and +// there's no need for a mutex here because synchronization happens with `join()` +// (the caller will never read this packet until the thread has exited). // -// This packet itself is then stored into a `JoinInner` which in turns is placed -// in `JoinHandle` and `JoinGuard`. Due to the usage of `UnsafeCell` we need to -// manually worry about impls like Send and Sync. The type `T` should -// already always be Send (otherwise the thread could not have been created) and -// this type is inherently Sync because no methods take &self. Regardless, -// however, we add inheriting impls for Send/Sync to this type to ensure it's -// Send/Sync and that future modifications will still appropriately classify it. -struct Packet(Arc>>>); - -unsafe impl Send for Packet {} -unsafe impl Sync for Packet {} - -/// Inner representation for JoinHandle -struct JoinInner { - native: imp::Thread, - thread: Thread, - packet: Packet, +// An Arc to the packet is stored into a `JoinInner` which in turns is placed +// in `JoinHandle`. Due to the usage of `UnsafeCell` we need to manually worry +// about impls like Send and Sync. The type `T` should already always be Send +// (otherwise the thread could not have been created) and this type is +// inherently Sync because no methods take &self. Regardless, however, we add +// inheriting impls for Send/Sync to this type to ensure it's Send/Sync and +// that future modifications will still appropriately classify it. +struct Packet<'scope, T> { + scope: Option<&'scope scoped::ScopeData>, + result: UnsafeCell>>, } -impl JoinInner { +unsafe impl<'scope, T: Send> Send for Packet<'scope, T> {} +unsafe impl<'scope, T: Sync> Sync for Packet<'scope, T> {} + +impl<'scope, T> Drop for Packet<'scope, T> { + fn drop(&mut self) { + if let Some(scope) = self.scope { + // If this packet was for a thread that ran in a scope, the thread + // panicked, and nobody consumed the panic payload, we put the + // panic payload in the scope so it can re-throw it, if it didn't + // already capture any panic yet. + if let Some(Err(e)) = self.result.get_mut().take() { + scope.panic_payload.lock().unwrap().get_or_insert(e); + } + // Book-keeping so the scope knows when it's done. + scope.decrement_n_running_threads(); + } + } +} + +/// Inner representation for JoinHandle +struct JoinInner<'scope, T> { + native: imp::Thread, + thread: Thread, + packet: Arc>, +} + +impl<'scope, T> JoinInner<'scope, T> { fn join(mut self) -> Result { self.native.join(); - Arc::get_mut(&mut self.packet.0).unwrap().get_mut().take().unwrap() + Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap() } } @@ -1333,7 +1377,7 @@ impl JoinInner { /// [`thread::Builder::spawn`]: Builder::spawn /// [`thread::spawn`]: spawn #[stable(feature = "rust1", since = "1.0.0")] -pub struct JoinHandle(JoinInner); +pub struct JoinHandle(JoinInner<'static, T>); #[stable(feature = "joinhandle_impl_send_sync", since = "1.29.0")] unsafe impl Send for JoinHandle {} @@ -1407,7 +1451,7 @@ impl JoinHandle { /// function has returned, but before the thread itself has stopped running. #[unstable(feature = "thread_is_running", issue = "90470")] pub fn is_running(&self) -> bool { - Arc::strong_count(&self.0.packet.0) > 1 + Arc::strong_count(&self.0.packet) > 1 } } diff --git a/library/std/src/thread/scoped.rs b/library/std/src/thread/scoped.rs new file mode 100644 index 000000000000..8e9a43e05bef --- /dev/null +++ b/library/std/src/thread/scoped.rs @@ -0,0 +1,132 @@ +use super::{current, park, Builder, JoinInner, Result, Thread}; +use crate::any::Any; +use crate::fmt; +use crate::io; +use crate::marker::PhantomData; +use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; +use crate::sync::atomic::{AtomicUsize, Ordering}; +use crate::sync::Mutex; + +/// TODO: documentation +pub struct Scope<'env> { + data: ScopeData, + env: PhantomData<&'env ()>, +} + +/// TODO: documentation +pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>); + +pub(super) struct ScopeData { + n_running_threads: AtomicUsize, + main_thread: Thread, + pub(super) panic_payload: Mutex>>, +} + +impl ScopeData { + pub(super) fn increment_n_running_threads(&self) { + // We check for 'overflow' with usize::MAX / 2, to make sure there's no + // chance it overflows to 0, which would result in unsoundness. + if self.n_running_threads.fetch_add(1, Ordering::Relaxed) == usize::MAX / 2 { + // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles. + self.decrement_n_running_threads(); + panic!("too many running threads in thread scope"); + } + } + pub(super) fn decrement_n_running_threads(&self) { + if self.n_running_threads.fetch_sub(1, Ordering::Release) == 1 { + self.main_thread.unpark(); + } + } +} + +/// TODO: documentation +pub fn scope<'env, F, T>(f: F) -> T +where + F: FnOnce(&Scope<'env>) -> T, +{ + let mut scope = Scope { + data: ScopeData { + n_running_threads: AtomicUsize::new(0), + main_thread: current(), + panic_payload: Mutex::new(None), + }, + env: PhantomData, + }; + + // Run `f`, but catch panics so we can make sure to wait for all the threads to join. + let result = catch_unwind(AssertUnwindSafe(|| f(&scope))); + + // Wait until all the threads are finished. + while scope.data.n_running_threads.load(Ordering::Acquire) != 0 { + park(); + } + + // Throw any panic from `f` or from any panicked thread, or the return value of `f` otherwise. + match result { + Err(e) => { + // `f` itself panicked. + resume_unwind(e); + } + Ok(result) => { + if let Some(panic_payload) = scope.data.panic_payload.get_mut().unwrap().take() { + // A thread panicked. + resume_unwind(panic_payload); + } else { + // Nothing panicked. + result + } + } + } +} + +impl<'env> Scope<'env> { + /// TODO: documentation + pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Builder::new().spawn_scoped(self, f).expect("failed to spawn thread") + } +} + +impl Builder { + fn spawn_scoped<'scope, 'env, F, T>( + self, + scope: &'scope Scope<'env>, + f: F, + ) -> io::Result> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(|| f(scope), Some(&scope.data)) }?)) + } +} + +impl<'scope, T> ScopedJoinHandle<'scope, T> { + /// TODO + pub fn join(self) -> Result { + self.0.join() + } + + /// TODO + pub fn thread(&self) -> &Thread { + &self.0.thread + } +} + +impl<'env> fmt::Debug for Scope<'env> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scope") + .field("n_running_threads", &self.data.n_running_threads.load(Ordering::Relaxed)) + .field("panic_payload", &self.data.panic_payload) + .finish_non_exhaustive() + } +} + +impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ScopedJoinHandle").finish_non_exhaustive() + } +}