215 lines
6.2 KiB
Rust
215 lines
6.2 KiB
Rust
use crate::fmt;
|
|
use crate::sync::{Condvar, Mutex};
|
|
|
|
/// A barrier enables multiple threads to synchronize the beginning
|
|
/// of some computation.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use std::sync::{Arc, Barrier};
|
|
/// use std::thread;
|
|
///
|
|
/// let mut handles = Vec::with_capacity(10);
|
|
/// let barrier = Arc::new(Barrier::new(10));
|
|
/// for _ in 0..10 {
|
|
/// let c = barrier.clone();
|
|
/// // The same messages will be printed together.
|
|
/// // You will NOT see any interleaving.
|
|
/// handles.push(thread::spawn(move|| {
|
|
/// println!("before wait");
|
|
/// c.wait();
|
|
/// println!("after wait");
|
|
/// }));
|
|
/// }
|
|
/// // Wait for other threads to finish.
|
|
/// for handle in handles {
|
|
/// handle.join().unwrap();
|
|
/// }
|
|
/// ```
|
|
#[stable(feature = "rust1", since = "1.0.0")]
|
|
pub struct Barrier {
|
|
lock: Mutex<BarrierState>,
|
|
cvar: Condvar,
|
|
num_threads: usize,
|
|
}
|
|
|
|
// Mutex-protected state of a reusable barrier. `generation_id` distinguishes
// successive rounds, which is what allows the same barrier to be waited on
// again once every thread has been released.
struct BarrierState {
    // Threads that have arrived in the current generation; reset to 0 by the
    // thread that completes the round.
    count: usize,
    // Advanced (with wrapping) each time a round completes; a waiting thread
    // compares it with the value it saw on arrival to detect its release.
    generation_id: usize,
}
|
|
|
|
/// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`]
|
|
/// have rendezvoused.
|
|
///
|
|
/// [`wait`]: struct.Barrier.html#method.wait
|
|
/// [`Barrier`]: struct.Barrier.html
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use std::sync::Barrier;
|
|
///
|
|
/// let barrier = Barrier::new(1);
|
|
/// let barrier_wait_result = barrier.wait();
|
|
/// ```
|
|
#[stable(feature = "rust1", since = "1.0.0")]
|
|
pub struct BarrierWaitResult(bool);
|
|
|
|
#[stable(feature = "std_debug", since = "1.16.0")]
|
|
impl fmt::Debug for Barrier {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.pad("Barrier { .. }")
|
|
}
|
|
}
|
|
|
|
impl Barrier {
|
|
/// Creates a new barrier that can block a given number of threads.
|
|
///
|
|
/// A barrier will block `n`-1 threads which call [`wait`] and then wake up
|
|
/// all threads at once when the `n`th thread calls [`wait`].
|
|
///
|
|
/// [`wait`]: #method.wait
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use std::sync::Barrier;
|
|
///
|
|
/// let barrier = Barrier::new(10);
|
|
/// ```
|
|
#[stable(feature = "rust1", since = "1.0.0")]
|
|
pub fn new(n: usize) -> Barrier {
|
|
Barrier {
|
|
lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
|
|
cvar: Condvar::new(),
|
|
num_threads: n,
|
|
}
|
|
}
|
|
|
|
/// Blocks the current thread until all threads have rendezvoused here.
|
|
///
|
|
/// Barriers are re-usable after all threads have rendezvoused once, and can
|
|
/// be used continuously.
|
|
///
|
|
/// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
|
|
/// returns `true` from [`is_leader`] when returning from this function, and
|
|
/// all other threads will receive a result that will return `false` from
|
|
/// [`is_leader`].
|
|
///
|
|
/// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
|
|
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use std::sync::{Arc, Barrier};
|
|
/// use std::thread;
|
|
///
|
|
/// let mut handles = Vec::with_capacity(10);
|
|
/// let barrier = Arc::new(Barrier::new(10));
|
|
/// for _ in 0..10 {
|
|
/// let c = barrier.clone();
|
|
/// // The same messages will be printed together.
|
|
/// // You will NOT see any interleaving.
|
|
/// handles.push(thread::spawn(move|| {
|
|
/// println!("before wait");
|
|
/// c.wait();
|
|
/// println!("after wait");
|
|
/// }));
|
|
/// }
|
|
/// // Wait for other threads to finish.
|
|
/// for handle in handles {
|
|
/// handle.join().unwrap();
|
|
/// }
|
|
/// ```
|
|
#[stable(feature = "rust1", since = "1.0.0")]
|
|
pub fn wait(&self) -> BarrierWaitResult {
|
|
let mut lock = self.lock.lock().unwrap();
|
|
let local_gen = lock.generation_id;
|
|
lock.count += 1;
|
|
if lock.count < self.num_threads {
|
|
// We need a while loop to guard against spurious wakeups.
|
|
// http://en.wikipedia.org/wiki/Spurious_wakeup
|
|
while local_gen == lock.generation_id && lock.count < self.num_threads {
|
|
lock = self.cvar.wait(lock).unwrap();
|
|
}
|
|
BarrierWaitResult(false)
|
|
} else {
|
|
lock.count = 0;
|
|
lock.generation_id = lock.generation_id.wrapping_add(1);
|
|
self.cvar.notify_all();
|
|
BarrierWaitResult(true)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[stable(feature = "std_debug", since = "1.16.0")]
|
|
impl fmt::Debug for BarrierWaitResult {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
|
|
}
|
|
}
|
|
|
|
impl BarrierWaitResult {
|
|
/// Returns `true` if this thread from [`wait`] is the "leader thread".
|
|
///
|
|
/// Only one thread will have `true` returned from their result, all other
|
|
/// threads will have `false` returned.
|
|
///
|
|
/// [`wait`]: struct.Barrier.html#method.wait
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use std::sync::Barrier;
|
|
///
|
|
/// let barrier = Barrier::new(1);
|
|
/// let barrier_wait_result = barrier.wait();
|
|
/// println!("{:?}", barrier_wait_result.is_leader());
|
|
/// ```
|
|
#[stable(feature = "rust1", since = "1.0.0")]
|
|
pub fn is_leader(&self) -> bool {
|
|
self.0
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::sync::mpsc::{channel, TryRecvError};
    use crate::sync::{Arc, Barrier};
    use crate::thread;

    #[test]
    #[cfg_attr(target_os = "emscripten", ignore)]
    fn test_barrier() {
        const N: usize = 10;

        let barrier = Arc::new(Barrier::new(N));
        let (tx, rx) = channel();

        for _ in 0..N - 1 {
            let c = barrier.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                tx.send(c.wait().is_leader()).unwrap();
            });
        }

        // At this point, all spawned threads should be blocked,
        // so we shouldn't get anything from the port
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        let mut leader_found = barrier.wait().is_leader();

        // Now, the barrier is cleared and we should get data.
        for _ in 0..N - 1 {
            if rx.recv().unwrap() {
                assert!(!leader_found);
                leader_found = true;
            }
        }
        assert!(leader_found);
    }

    // The barrier must be reusable: several consecutive rounds on the same
    // barrier each release everyone and elect exactly one leader.
    #[test]
    #[cfg_attr(target_os = "emscripten", ignore)]
    fn test_barrier_reuse() {
        const N: usize = 4;
        const ROUNDS: usize = 5;

        let barrier = Arc::new(Barrier::new(N));
        let handles: Vec<_> = (0..N)
            .map(|_| {
                let c = Arc::clone(&barrier);
                thread::spawn(move || (0..ROUNDS).filter(|_| c.wait().is_leader()).count())
            })
            .collect();

        let leaders: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(leaders, ROUNDS);
    }

    // With zero or one participant the caller is always the last arrival,
    // so `wait` returns immediately and reports leadership every time.
    #[test]
    fn test_barrier_trivial_sizes() {
        for n in 0..=1 {
            let barrier = Barrier::new(n);
            for _ in 0..3 {
                assert!(barrier.wait().is_leader());
            }
        }
    }
}
|