diff --git a/src/eval.rs b/src/eval.rs index 085a53862fd4..ab82c39836b2 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -205,15 +205,25 @@ pub fn eval_main<'tcx>(tcx: TyCtxt<'tcx>, main_id: DefId, config: MiriConfig) -> // Perform the main execution. let res: InterpResult<'_, i64> = (|| { // Main loop. - while ecx.schedule()? { - assert!(ecx.step()?, "a terminated thread was scheduled for execution"); + loop { + match ecx.schedule()? { + SchedulingAction::ExecuteStep => { + assert!(ecx.step()?, "a terminated thread was scheduled for execution"); + } + SchedulingAction::ExecuteDtors => { + ecx.run_tls_dtors_for_active_thread()?; + } + SchedulingAction::Stop => { + break; + } + } ecx.process_diagnostics(); } // Read the return code pointer *before* we run TLS destructors, to assert // that it was written to by the time that `start` lang item returned. let return_code = ecx.read_scalar(ret_place.into())?.not_undef()?.to_machine_isize(&ecx)?; // Global destructors. - ecx.run_tls_dtors()?; + ecx.run_windows_tls_dtors()?; Ok(return_code) })(); diff --git a/src/lib.rs b/src/lib.rs index 96e6f7d63e69..beee94b918b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,7 @@ pub use crate::stacked_borrows::{ EvalContextExt as StackedBorEvalContextExt, Item, Permission, PtrId, Stack, Stacks, Tag, }; pub use crate::thread::{ - EvalContextExt as ThreadsEvalContextExt, ThreadId, ThreadManager, ThreadState, + EvalContextExt as ThreadsEvalContextExt, SchedulingAction, ThreadId, ThreadManager, ThreadState, }; /// Insert rustc arguments at the beginning of the argument list that Miri wants to be diff --git a/src/shims/foreign_items/posix/macos.rs b/src/shims/foreign_items/posix/macos.rs index dd3dba6ec07c..9f65d0f9c47d 100644 --- a/src/shims/foreign_items/posix/macos.rs +++ b/src/shims/foreign_items/posix/macos.rs @@ -82,7 +82,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let dtor = this.read_scalar(args[0])?.not_undef()?; let dtor = this.memory.get_fn(dtor)?.as_instance()?; let data = this.read_scalar(args[1])?.not_undef()?; - this.machine.tls.set_global_dtor(dtor, data)?; + let active_thread = this.get_active_thread()?; + this.machine.tls.set_global_dtor(active_thread, dtor, data)?; } // Querying system information diff --git a/src/shims/tls.rs b/src/shims/tls.rs index da0c585958a3..6dc3025acd5a 100644 --- a/src/shims/tls.rs +++ b/src/shims/tls.rs @@ -2,15 +2,17 @@ use std::collections::BTreeMap; use std::collections::btree_map::Entry; +use std::collections::HashSet; use log::trace; +use rustc_index::vec::Idx; use rustc_middle::ty; use rustc_target::abi::{Size, HasDataLayout}; use crate::{ HelpersEvalContextExt, InterpResult, MPlaceTy, Scalar, StackPopCleanup, Tag, ThreadId, - ThreadState, ThreadsEvalContextExt, + ThreadsEvalContextExt, }; pub type TlsKey = u128; @@ -32,11 +34,11 @@ pub struct TlsData<'tcx> { /// pthreads-style thread-local storage. keys: BTreeMap>, - /// A single global dtor (that's how things work on macOS) with a data argument. - global_dtor: Option<(ty::Instance<'tcx>, Scalar)>, + /// A single global per thread dtor (that's how things work on macOS) with a data argument. + global_dtors: BTreeMap, Scalar)>, /// Whether we are in the "destruct" phase, during which some operations are UB. - dtors_running: bool, + dtors_running: HashSet, } impl<'tcx> Default for TlsData<'tcx> { @@ -44,8 +46,8 @@ impl<'tcx> Default for TlsData<'tcx> { TlsData { next_key: 1, // start with 1 as we must not use 0 on Windows keys: Default::default(), - global_dtor: None, - dtors_running: false, + global_dtors: Default::default(), + dtors_running: Default::default(), } } } @@ -112,16 +114,15 @@ impl<'tcx> TlsData<'tcx> { } } - pub fn set_global_dtor(&mut self, dtor: ty::Instance<'tcx>, data: Scalar) -> InterpResult<'tcx> { - if self.dtors_running { + /// Set global dtor for the given thread. + pub fn set_global_dtor(&mut self, thread: ThreadId, dtor: ty::Instance<'tcx>, data: Scalar) -> InterpResult<'tcx> { + if self.dtors_running.contains(&thread) { // UB, according to libstd docs. throw_ub_format!("setting global destructor while destructors are already running"); } - if self.global_dtor.is_some() { - throw_unsup_format!("setting more than one global destructor is not supported"); + if self.global_dtors.insert(thread, (dtor, data)).is_some() { + throw_unsup_format!("setting more than one global destructor for the same thread is not supported"); } - - self.global_dtor = Some((dtor, data)); Ok(()) } @@ -148,7 +149,7 @@ impl<'tcx> TlsData<'tcx> { &mut self, key: Option, thread_id: ThreadId, - ) -> Option<(ty::Instance<'tcx>, ThreadId, Scalar, TlsKey)> { + ) -> Option<(ty::Instance<'tcx>, Scalar, TlsKey)> { use std::collections::Bound::*; let thread_local = &mut self.keys; @@ -161,9 +162,9 @@ impl<'tcx> TlsData<'tcx> { { match data.entry(thread_id) { Entry::Occupied(entry) => { - let (thread_id, data_scalar) = entry.remove_entry(); + let data_scalar = entry.remove(); if let Some(dtor) = dtor { - let ret = Some((*dtor, thread_id, data_scalar, key)); + let ret = Some((*dtor, data_scalar, key)); return ret; } } @@ -176,41 +177,61 @@ impl<'tcx> TlsData<'tcx> { impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriEvalContext<'mir, 'tcx> {} pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx> { - /// Run TLS destructors for all threads. - fn run_tls_dtors(&mut self) -> InterpResult<'tcx> { + + /// Run TLS destructors for the main thread on Windows. The implementation + /// assumes that we do not support concurrency on Windows yet. + /// + /// Note: on non-Windows OS this function is a no-op. + fn run_windows_tls_dtors(&mut self) -> InterpResult<'tcx> { let this = self.eval_context_mut(); - assert!(!this.machine.tls.dtors_running, "running TLS dtors twice"); - this.machine.tls.dtors_running = true; - - if this.tcx.sess.target.target.target_os == "windows" { - // Windows has a special magic linker section that is run on certain events. - // Instead of searching for that section and supporting arbitrary hooks in there - // (that would be basically https://github.com/rust-lang/miri/issues/450), - // we specifically look up the static in libstd that we know is placed - // in that section. - let thread_callback = this.eval_path_scalar(&["std", "sys", "windows", "thread_local", "p_thread_callback"])?; - let thread_callback = this.memory.get_fn(thread_callback.not_undef()?)?.as_instance()?; - - // The signature of this function is `unsafe extern "system" fn(h: c::LPVOID, dwReason: c::DWORD, pv: c::LPVOID)`. - let reason = this.eval_path_scalar(&["std", "sys", "windows", "c", "DLL_PROCESS_DETACH"])?; - let ret_place = MPlaceTy::dangling(this.machine.layouts.unit, this).into(); - this.call_function( - thread_callback, - &[Scalar::null_ptr(this).into(), reason.into(), Scalar::null_ptr(this).into()], - Some(ret_place), - StackPopCleanup::None { cleanup: true }, - )?; - - // step until out of stackframes - this.run()?; - - // Windows doesn't have other destructors. + if this.tcx.sess.target.target.target_os != "windows" { return Ok(()); } + let active_thread = this.get_active_thread()?; + assert_eq!(active_thread.index(), 0, "concurrency on Windows not supported"); + assert!(!this.machine.tls.dtors_running.contains(&active_thread), "running TLS dtors twice"); + this.machine.tls.dtors_running.insert(active_thread); + // Windows has a special magic linker section that is run on certain events. + // Instead of searching for that section and supporting arbitrary hooks in there + // (that would be basically https://github.com/rust-lang/miri/issues/450), + // we specifically look up the static in libstd that we know is placed + // in that section. + let thread_callback = this.eval_path_scalar(&["std", "sys", "windows", "thread_local", "p_thread_callback"])?; + let thread_callback = this.memory.get_fn(thread_callback.not_undef()?)?.as_instance()?; + + // The signature of this function is `unsafe extern "system" fn(h: c::LPVOID, dwReason: c::DWORD, pv: c::LPVOID)`. + let reason = this.eval_path_scalar(&["std", "sys", "windows", "c", "DLL_PROCESS_DETACH"])?; + let ret_place = MPlaceTy::dangling(this.machine.layouts.unit, this).into(); + this.call_function( + thread_callback, + &[Scalar::null_ptr(this).into(), reason.into(), Scalar::null_ptr(this).into()], + Some(ret_place), + StackPopCleanup::None { cleanup: true }, + )?; + + // step until out of stackframes + this.run()?; + + // Windows doesn't have other destructors. + Ok(()) + } + + /// Run TLS destructors for the active thread. + /// + /// Note: on Windows OS this function is a no-op because we do not support + /// concurrency on Windows yet. + fn run_tls_dtors_for_active_thread(&mut self) -> InterpResult<'tcx> { + let this = self.eval_context_mut(); + if this.tcx.sess.target.target.target_os == "windows" { + return Ok(()); + } + let thread_id = this.get_active_thread()?; + assert!(!this.machine.tls.dtors_running.contains(&thread_id), "running TLS dtors twice"); + this.machine.tls.dtors_running.insert(thread_id); // The macOS global dtor runs "before any TLS slots get freed", so do that first. - if let Some((instance, data)) = this.machine.tls.global_dtor { - trace!("Running global dtor {:?} on {:?}", instance, data); + if let Some(&(instance, data)) = this.machine.tls.global_dtors.get(&thread_id) { + trace!("Running global dtor {:?} on {:?} at {:?}", instance, data, thread_id); let ret_place = MPlaceTy::dangling(this.machine.layouts.unit, this).into(); this.call_function( @@ -224,35 +245,31 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx this.run()?; } - // Now run the "keyed" destructors. - for (thread_id, thread_state) in this.get_all_thread_ids_with_states() { - assert!(thread_state == ThreadState::Terminated, - "TLS destructors should be executed after all threads terminated."); - this.set_active_thread(thread_id)?; - let mut dtor = this.machine.tls.fetch_tls_dtor(None, thread_id); - while let Some((instance, thread_id, ptr, key)) = dtor { - trace!("Running TLS dtor {:?} on {:?} at {:?}", instance, ptr, thread_id); - assert!(!this.is_null(ptr).unwrap(), "Data can't be NULL when dtor is called!"); + assert!(this.has_terminated(thread_id)?, "running TLS dtors for non-terminated thread"); + let mut dtor = this.machine.tls.fetch_tls_dtor(None, thread_id); + while let Some((instance, ptr, key)) = dtor { + trace!("Running TLS dtor {:?} on {:?} at {:?}", instance, ptr, thread_id); + assert!(!this.is_null(ptr).unwrap(), "Data can't be NULL when dtor is called!"); - let ret_place = MPlaceTy::dangling(this.layout_of(this.tcx.mk_unit())?, this).into(); - this.call_function( - instance, - &[ptr.into()], - Some(ret_place), - StackPopCleanup::None { cleanup: true }, - )?; + let ret_place = MPlaceTy::dangling(this.machine.layouts.unit, this).into(); + this.call_function( + instance, + &[ptr.into()], + Some(ret_place), + StackPopCleanup::None { cleanup: true }, + )?; - // step until out of stackframes - this.run()?; + // step until out of stackframes + this.run()?; - // Fetch next dtor after `key`. - dtor = match this.machine.tls.fetch_tls_dtor(Some(key), thread_id) { - dtor @ Some(_) => dtor, - // We ran each dtor once, start over from the beginning. - None => this.machine.tls.fetch_tls_dtor(None, thread_id), - }; - } + // Fetch next dtor after `key`. + dtor = match this.machine.tls.fetch_tls_dtor(Some(key), thread_id) { + dtor @ Some(_) => dtor, + // We ran each dtor once, start over from the beginning. + None => this.machine.tls.fetch_tls_dtor(None, thread_id), + }; } + Ok(()) } } diff --git a/src/thread.rs b/src/thread.rs index 72584b726557..d40b2a176e73 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -17,6 +17,16 @@ use rustc_middle::{ use crate::*; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SchedulingAction { + /// Execute step on the active thread. + ExecuteStep, + /// Execute destructors of the active thread. + ExecuteDtors, + /// Stop the program. + Stop, +} + /// A thread identifier. #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct ThreadId(usize); @@ -197,6 +207,11 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { self.active_thread } + /// Has the given thread terminated? + fn has_terminated(&self, thread_id: ThreadId) -> bool { + self.threads[thread_id].state == ThreadState::Terminated + } + /// Get the borrow of the currently active thread. fn active_thread_mut(&mut self) -> &mut Thread<'mir, 'tcx> { &mut self.threads[self.active_thread] @@ -234,11 +249,6 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { self.active_thread_mut().thread_name = Some(new_thread_name); } - /// Get ids and states of all threads ever allocated. - fn get_all_thread_ids_with_states(&self) -> Vec<(ThreadId, ThreadState)> { - self.threads.iter_enumerated().map(|(id, thread)| (id, thread.state)).collect() - } - /// Allocate a new blockset id. fn create_blockset(&mut self) -> BlockSetId { self.blockset_counter = self.blockset_counter.checked_add(1).unwrap(); @@ -265,10 +275,8 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { None } - /// Decide which thread to run next. - /// - /// Returns `false` if all threads terminated. - fn schedule(&mut self) -> InterpResult<'tcx, bool> { + /// Decide which action to take next and on which thread. + fn schedule(&mut self) -> InterpResult<'tcx, SchedulingAction> { if self.threads[self.active_thread].check_terminated() { // Check if we need to unblock any threads. for (i, thread) in self.threads.iter_enumerated_mut() { @@ -277,18 +285,19 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { thread.state = ThreadState::Enabled; } } + return Ok(SchedulingAction::ExecuteDtors); } if self.threads[self.active_thread].state == ThreadState::Enabled { - return Ok(true); + return Ok(SchedulingAction::ExecuteStep); } if let Some(enabled_thread) = self.threads.iter().position(|thread| thread.state == ThreadState::Enabled) { self.active_thread = ThreadId::new(enabled_thread); - return Ok(true); + return Ok(SchedulingAction::ExecuteStep); } if self.threads.iter().all(|thread| thread.state == ThreadState::Terminated) { - Ok(false) + Ok(SchedulingAction::Stop) } else { throw_machine_stop!(TerminationInfo::Deadlock); } @@ -409,6 +418,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx Ok(this.machine.threads.get_active_thread_id()) } + fn has_terminated(&self, thread_id: ThreadId) -> InterpResult<'tcx, bool> { + let this = self.eval_context_ref(); + Ok(this.machine.threads.has_terminated(thread_id)) + } + fn active_thread_stack(&self) -> &[Frame<'mir, 'tcx, Tag, FrameData<'tcx>>] { let this = self.eval_context_ref(); this.machine.threads.active_thread_stack() @@ -424,11 +438,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx Ok(this.machine.threads.set_thread_name(new_thread_name)) } - fn get_all_thread_ids_with_states(&mut self) -> Vec<(ThreadId, ThreadState)> { - let this = self.eval_context_mut(); - this.machine.threads.get_all_thread_ids_with_states() - } - fn create_blockset(&mut self) -> InterpResult<'tcx, BlockSetId> { let this = self.eval_context_mut(); Ok(this.machine.threads.create_blockset()) @@ -444,10 +453,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx Ok(this.machine.threads.unblock_random_thread(set)) } - /// Decide which thread to run next. - /// - /// Returns `false` if all threads terminated. - fn schedule(&mut self) -> InterpResult<'tcx, bool> { + /// Decide which action to take next and on which thread. + fn schedule(&mut self) -> InterpResult<'tcx, SchedulingAction> { let this = self.eval_context_mut(); this.machine.threads.schedule() }