559 lines
17 KiB
Rust
559 lines
17 KiB
Rust
/// A simple queue implementation for synchronization primitives.
|
|
///
|
|
/// This queue is used to implement condition variable and mutexes.
|
|
///
|
|
/// Users of this API are expected to use the `WaitVariable<T>` type. Since
|
|
/// that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
|
|
/// allow shared access.
|
|
///
|
|
/// Since userspace may send spurious wake-ups, the wakeup event state is
|
|
/// recorded in the enclave. The wakeup event state is protected by a spinlock.
|
|
/// The queue and associated wait state are stored in a `WaitVariable`.
|
|
|
|
use crate::ops::{Deref, DerefMut};
|
|
use crate::num::NonZeroUsize;
|
|
|
|
use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
|
|
use super::abi::usercalls;
|
|
use super::abi::thread;
|
|
|
|
use self::unsafe_list::{UnsafeList, UnsafeListEntry};
|
|
pub use self::spin_mutex::{SpinMutex, SpinMutexGuard, try_lock_or_false};
|
|
|
|
/// An queue entry in a `WaitQueue`.
|
|
struct WaitEntry {
|
|
/// TCS address of the thread that is waiting
|
|
tcs: Tcs,
|
|
/// Whether this thread has been notified to be awoken
|
|
wake: bool
|
|
}
|
|
|
|
/// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
|
|
/// queue and the data are synchronized, since the type itself is not `Sync`.
|
|
///
|
|
/// Consumers of this API should use a synchronization primitive for shared
|
|
/// access, such as `SpinMutex`.
|
|
#[derive(Default)]
|
|
pub struct WaitVariable<T> {
|
|
queue: WaitQueue,
|
|
lock: T
|
|
}
|
|
|
|
impl<T> WaitVariable<T> {
|
|
pub const fn new(var: T) -> Self {
|
|
WaitVariable {
|
|
queue: WaitQueue::new(),
|
|
lock: var
|
|
}
|
|
}
|
|
|
|
pub fn queue_empty(&self) -> bool {
|
|
self.queue.is_empty()
|
|
}
|
|
|
|
pub fn lock_var(&self) -> &T {
|
|
&self.lock
|
|
}
|
|
|
|
pub fn lock_var_mut(&mut self) -> &mut T {
|
|
&mut self.lock
|
|
}
|
|
}
|
|
|
|
#[derive(Copy, Clone)]
|
|
pub enum NotifiedTcs {
|
|
Single(Tcs),
|
|
All { count: NonZeroUsize }
|
|
}
|
|
|
|
/// An RAII guard that will notify a set of target threads as well as unlock
|
|
/// a mutex on drop.
|
|
pub struct WaitGuard<'a, T: 'a> {
|
|
mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
|
|
notified_tcs: NotifiedTcs
|
|
}
|
|
|
|
/// A queue of threads that are waiting on some synchronization primitive.
|
|
///
|
|
/// `UnsafeList` entries are allocated on the waiting thread's stack. This
|
|
/// avoids any global locking that might happen in the heap allocator. This is
|
|
/// safe because the waiting thread will not return from that stack frame until
|
|
/// after it is notified. The notifying thread ensures to clean up any
|
|
/// references to the list entries before sending the wakeup event.
|
|
pub struct WaitQueue {
|
|
// We use an inner Mutex here to protect the data in the face of spurious
|
|
// wakeups.
|
|
inner: UnsafeList<SpinMutex<WaitEntry>>,
|
|
}
|
|
unsafe impl Send for WaitQueue {}
|
|
|
|
impl Default for WaitQueue {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<'a, T> WaitGuard<'a, T> {
|
|
/// Returns which TCSes will be notified when this guard drops.
|
|
pub fn notified_tcs(&self) -> NotifiedTcs {
|
|
self.notified_tcs
|
|
}
|
|
|
|
/// Drop this `WaitGuard`, after dropping another `guard`.
|
|
pub fn drop_after<U>(self, guard: U) {
|
|
drop(guard);
|
|
drop(self);
|
|
}
|
|
}
|
|
|
|
impl<'a, T> Deref for WaitGuard<'a, T> {
|
|
type Target = SpinMutexGuard<'a, WaitVariable<T>>;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
self.mutex_guard.as_ref().unwrap()
|
|
}
|
|
}
|
|
|
|
impl<'a, T> DerefMut for WaitGuard<'a, T> {
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
self.mutex_guard.as_mut().unwrap()
|
|
}
|
|
}
|
|
|
|
impl<'a, T> Drop for WaitGuard<'a, T> {
|
|
fn drop(&mut self) {
|
|
drop(self.mutex_guard.take());
|
|
let target_tcs = match self.notified_tcs {
|
|
NotifiedTcs::Single(tcs) => Some(tcs),
|
|
NotifiedTcs::All { .. } => None
|
|
};
|
|
rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
|
|
}
|
|
}
|
|
|
|
impl WaitQueue {
|
|
pub const fn new() -> Self {
|
|
WaitQueue {
|
|
inner: UnsafeList::new()
|
|
}
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.inner.is_empty()
|
|
}
|
|
|
|
/// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
|
|
/// until a wakeup event.
|
|
///
|
|
/// This function does not return until this thread has been awoken.
|
|
pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
|
|
// very unsafe: check requirements of UnsafeList::push
|
|
unsafe {
|
|
let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
|
|
tcs: thread::current(),
|
|
wake: false
|
|
}));
|
|
let entry = guard.queue.inner.push(&mut entry);
|
|
drop(guard);
|
|
before_wait();
|
|
while !entry.lock().wake {
|
|
// don't panic, this would invalidate `entry` during unwinding
|
|
let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
|
|
rtassert!(eventset & EV_UNPARK == EV_UNPARK);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Either find the next waiter on the wait queue, or return the mutex
|
|
/// guard unchanged.
|
|
///
|
|
/// If a waiter is found, a `WaitGuard` is returned which will notify the
|
|
/// waiter when it is dropped.
|
|
pub fn notify_one<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
|
|
-> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
|
|
{
|
|
unsafe {
|
|
if let Some(entry) = guard.queue.inner.pop() {
|
|
let mut entry_guard = entry.lock();
|
|
let tcs = entry_guard.tcs;
|
|
entry_guard.wake = true;
|
|
drop(entry);
|
|
Ok(WaitGuard {
|
|
mutex_guard: Some(guard),
|
|
notified_tcs: NotifiedTcs::Single(tcs)
|
|
})
|
|
} else {
|
|
Err(guard)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Either find any and all waiters on the wait queue, or return the mutex
|
|
/// guard unchanged.
|
|
///
|
|
/// If at least one waiter is found, a `WaitGuard` is returned which will
|
|
/// notify all waiters when it is dropped.
|
|
pub fn notify_all<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
|
|
-> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
|
|
{
|
|
unsafe {
|
|
let mut count = 0;
|
|
while let Some(entry) = guard.queue.inner.pop() {
|
|
count += 1;
|
|
let mut entry_guard = entry.lock();
|
|
entry_guard.wake = true;
|
|
}
|
|
if let Some(count) = NonZeroUsize::new(count) {
|
|
Ok(WaitGuard {
|
|
mutex_guard: Some(guard),
|
|
notified_tcs: NotifiedTcs::All { count }
|
|
})
|
|
} else {
|
|
Err(guard)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A doubly-linked list where callers are in charge of memory allocation
|
|
/// of the nodes in the list.
|
|
mod unsafe_list {
|
|
use crate::ptr::NonNull;
|
|
use crate::mem;
|
|
|
|
pub struct UnsafeListEntry<T> {
|
|
next: NonNull<UnsafeListEntry<T>>,
|
|
prev: NonNull<UnsafeListEntry<T>>,
|
|
value: Option<T>
|
|
}
|
|
|
|
impl<T> UnsafeListEntry<T> {
|
|
fn dummy() -> Self {
|
|
UnsafeListEntry {
|
|
next: NonNull::dangling(),
|
|
prev: NonNull::dangling(),
|
|
value: None
|
|
}
|
|
}
|
|
|
|
pub fn new(value: T) -> Self {
|
|
UnsafeListEntry {
|
|
value: Some(value),
|
|
..Self::dummy()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct UnsafeList<T> {
|
|
head_tail: NonNull<UnsafeListEntry<T>>,
|
|
head_tail_entry: Option<UnsafeListEntry<T>>,
|
|
}
|
|
|
|
impl<T> UnsafeList<T> {
|
|
pub const fn new() -> Self {
|
|
unsafe {
|
|
UnsafeList {
|
|
head_tail: NonNull::new_unchecked(1 as _),
|
|
head_tail_entry: None
|
|
}
|
|
}
|
|
}
|
|
|
|
unsafe fn init(&mut self) {
|
|
if self.head_tail_entry.is_none() {
|
|
self.head_tail_entry = Some(UnsafeListEntry::dummy());
|
|
self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
|
|
self.head_tail.as_mut().next = self.head_tail;
|
|
self.head_tail.as_mut().prev = self.head_tail;
|
|
}
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
unsafe {
|
|
if self.head_tail_entry.is_some() {
|
|
let first = self.head_tail.as_ref().next;
|
|
if first == self.head_tail {
|
|
// ,-------> /---------\ next ---,
|
|
// | |head_tail| |
|
|
// `--- prev \---------/ <-------`
|
|
rtassert!(self.head_tail.as_ref().prev == first);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
} else {
|
|
true
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Pushes an entry onto the back of the list.
|
|
///
|
|
/// # Safety
|
|
///
|
|
/// The entry must remain allocated until the entry is removed from the
|
|
/// list AND the caller who popped is done using the entry. Special
|
|
/// care must be taken in the caller of `push` to ensure unwinding does
|
|
/// not destroy the stack frame containing the entry.
|
|
pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
|
|
self.init();
|
|
|
|
// BEFORE:
|
|
// /---------\ next ---> /---------\
|
|
// ... |prev_tail| |head_tail| ...
|
|
// \---------/ <--- prev \---------/
|
|
//
|
|
// AFTER:
|
|
// /---------\ next ---> /-----\ next ---> /---------\
|
|
// ... |prev_tail| |entry| |head_tail| ...
|
|
// \---------/ <--- prev \-----/ <--- prev \---------/
|
|
let mut entry = NonNull::new_unchecked(entry);
|
|
let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
|
|
entry.as_mut().prev = prev_tail;
|
|
entry.as_mut().next = self.head_tail;
|
|
prev_tail.as_mut().next = entry;
|
|
// unwrap ok: always `Some` on non-dummy entries
|
|
(*entry.as_ptr()).value.as_ref().unwrap()
|
|
}
|
|
|
|
/// Pops an entry from the front of the list.
|
|
///
|
|
/// # Safety
|
|
///
|
|
/// The caller must make sure to synchronize ending the borrow of the
|
|
/// return value and deallocation of the containing entry.
|
|
pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
|
|
self.init();
|
|
|
|
if self.is_empty() {
|
|
None
|
|
} else {
|
|
// BEFORE:
|
|
// /---------\ next ---> /-----\ next ---> /------\
|
|
// ... |head_tail| |first| |second| ...
|
|
// \---------/ <--- prev \-----/ <--- prev \------/
|
|
//
|
|
// AFTER:
|
|
// /---------\ next ---> /------\
|
|
// ... |head_tail| |second| ...
|
|
// \---------/ <--- prev \------/
|
|
let mut first = self.head_tail.as_mut().next;
|
|
let mut second = first.as_mut().next;
|
|
self.head_tail.as_mut().next = second;
|
|
second.as_mut().prev = self.head_tail;
|
|
first.as_mut().next = NonNull::dangling();
|
|
first.as_mut().prev = NonNull::dangling();
|
|
// unwrap ok: always `Some` on non-dummy entries
|
|
Some((*first.as_ptr()).value.as_ref().unwrap())
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::cell::Cell;
|
|
|
|
unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
|
|
assert!(list.pop().is_none(), "assertion failed: list is not empty");
|
|
}
|
|
|
|
#[test]
|
|
fn init_empty() {
|
|
unsafe {
|
|
assert_empty(&mut UnsafeList::<i32>::new());
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn push_pop() {
|
|
unsafe {
|
|
let mut node = UnsafeListEntry::new(1234);
|
|
let mut list = UnsafeList::new();
|
|
assert_eq!(list.push(&mut node), &1234);
|
|
assert_eq!(list.pop().unwrap(), &1234);
|
|
assert_empty(&mut list);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn complex_pushes_pops() {
|
|
unsafe {
|
|
let mut node1 = UnsafeListEntry::new(1234);
|
|
let mut node2 = UnsafeListEntry::new(4567);
|
|
let mut node3 = UnsafeListEntry::new(9999);
|
|
let mut node4 = UnsafeListEntry::new(8642);
|
|
let mut list = UnsafeList::new();
|
|
list.push(&mut node1);
|
|
list.push(&mut node2);
|
|
assert_eq!(list.pop().unwrap(), &1234);
|
|
list.push(&mut node3);
|
|
assert_eq!(list.pop().unwrap(), &4567);
|
|
assert_eq!(list.pop().unwrap(), &9999);
|
|
assert_empty(&mut list);
|
|
list.push(&mut node4);
|
|
assert_eq!(list.pop().unwrap(), &8642);
|
|
assert_empty(&mut list);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn cell() {
|
|
unsafe {
|
|
let mut node = UnsafeListEntry::new(Cell::new(0));
|
|
let mut list = UnsafeList::new();
|
|
let noderef = list.push(&mut node);
|
|
assert_eq!(noderef.get(), 0);
|
|
list.pop().unwrap().set(1);
|
|
assert_empty(&mut list);
|
|
assert_eq!(noderef.get(), 1);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Trivial spinlock-based implementation of `sync::Mutex`.
|
|
// FIXME: Perhaps use Intel TSX to avoid locking?
|
|
mod spin_mutex {
|
|
use crate::cell::UnsafeCell;
|
|
use crate::sync::atomic::{AtomicBool, Ordering, spin_loop_hint};
|
|
use crate::ops::{Deref, DerefMut};
|
|
|
|
#[derive(Default)]
|
|
pub struct SpinMutex<T> {
|
|
value: UnsafeCell<T>,
|
|
lock: AtomicBool,
|
|
}
|
|
|
|
unsafe impl<T: Send> Send for SpinMutex<T> {}
|
|
unsafe impl<T: Send> Sync for SpinMutex<T> {}
|
|
|
|
pub struct SpinMutexGuard<'a, T: 'a> {
|
|
mutex: &'a SpinMutex<T>,
|
|
}
|
|
|
|
impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
|
|
unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
|
|
|
|
impl<T> SpinMutex<T> {
|
|
pub const fn new(value: T) -> Self {
|
|
SpinMutex {
|
|
value: UnsafeCell::new(value),
|
|
lock: AtomicBool::new(false)
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
pub fn lock(&self) -> SpinMutexGuard<'_, T> {
|
|
loop {
|
|
match self.try_lock() {
|
|
None => while self.lock.load(Ordering::Relaxed) {
|
|
spin_loop_hint()
|
|
},
|
|
Some(guard) => return guard
|
|
}
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
|
|
if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
|
|
Some(SpinMutexGuard {
|
|
mutex: self,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Lock the Mutex or return false.
|
|
pub macro try_lock_or_false {
|
|
($e:expr) => {
|
|
if let Some(v) = $e.try_lock() {
|
|
v
|
|
} else {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, T> Deref for SpinMutexGuard<'a, T> {
|
|
type Target = T;
|
|
|
|
fn deref(&self) -> &T {
|
|
unsafe {
|
|
&*self.mutex.value.get()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
|
|
fn deref_mut(&mut self) -> &mut T {
|
|
unsafe {
|
|
&mut*self.mutex.value.get()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, T> Drop for SpinMutexGuard<'a, T> {
|
|
fn drop(&mut self) {
|
|
self.mutex.lock.store(false, Ordering::Release)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
#![allow(deprecated)]
|
|
|
|
use super::*;
|
|
use crate::sync::Arc;
|
|
use crate::thread;
|
|
use crate::time::{SystemTime, Duration};
|
|
|
|
#[test]
|
|
fn sleep() {
|
|
let mutex = Arc::new(SpinMutex::<i32>::default());
|
|
let mutex2 = mutex.clone();
|
|
let guard = mutex.lock();
|
|
let t1 = thread::spawn(move || {
|
|
*mutex2.lock() = 1;
|
|
});
|
|
|
|
// "sleep" for 50ms
|
|
// FIXME: https://github.com/fortanix/rust-sgx/issues/31
|
|
let start = SystemTime::now();
|
|
let max = Duration::from_millis(50);
|
|
while start.elapsed().unwrap() < max {}
|
|
|
|
assert_eq!(*guard, 0);
|
|
drop(guard);
|
|
t1.join().unwrap();
|
|
assert_eq!(*mutex.lock(), 1);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::sync::Arc;
|
|
use crate::thread;
|
|
|
|
#[test]
|
|
fn queue() {
|
|
let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
|
|
let wq2 = wq.clone();
|
|
|
|
let locked = wq.lock();
|
|
|
|
let t1 = thread::spawn(move || {
|
|
// if we obtain the lock, the main thread should be waiting
|
|
assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
|
|
});
|
|
|
|
WaitQueue::wait(locked, ||{});
|
|
|
|
t1.join().unwrap();
|
|
}
|
|
}
|