From 99433a1ffdb1724cbafeafba88d6d52fee579bd1 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Fri, 26 May 2017 20:02:51 -0700 Subject: [PATCH] improve fn pointer signature check to allow some casts that should be permitted Also properly check the "non-capturing Fn to fn" case --- src/terminator/mod.rs | 74 ++++++++++++++++++++++++++---- tests/compile-fail/cast_fn_ptr2.rs | 9 ++++ tests/run-pass/cast_fn_ptr.rs | 9 ++++ 3 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 tests/compile-fail/cast_fn_ptr2.rs create mode 100644 tests/run-pass/cast_fn_ptr.rs diff --git a/src/terminator/mod.rs b/src/terminator/mod.rs index 09361aa43b3b..0fea3a174be1 100644 --- a/src/terminator/mod.rs +++ b/src/terminator/mod.rs @@ -72,15 +72,8 @@ impl<'a, 'tcx> EvalContext<'a, 'tcx> { ty::TyFnDef(_, _, real_sig) => { let sig = self.erase_lifetimes(&sig); let real_sig = self.erase_lifetimes(&real_sig); - match instance.def { - // FIXME: this needs checks for weird transmutes - // we need to bail here, because noncapturing closures as fn ptrs fail the checks - ty::InstanceDef::ClosureOnceShim{..} => {} - _ => if sig.abi != real_sig.abi || - sig.variadic != real_sig.variadic || - sig.inputs_and_output != real_sig.inputs_and_output { - return Err(EvalError::FunctionPointerTyMismatch(real_sig, sig)); - }, + if !self.check_sig_compat(sig, real_sig)? { + return Err(EvalError::FunctionPointerTyMismatch(real_sig, sig)); } }, ref other => bug!("instance def ty: {:?}", other), @@ -138,6 +131,69 @@ impl<'a, 'tcx> EvalContext<'a, 'tcx> { Ok(()) } + /// Decides whether it is okay to call the method with signature `real_sig` using signature `sig` + fn check_sig_compat( + &mut self, + sig: ty::FnSig<'tcx>, + real_sig: ty::FnSig<'tcx>, + ) -> EvalResult<'tcx, bool> { + fn check_ty_compat<'tcx>( + ty: ty::Ty<'tcx>, + real_ty: ty::Ty<'tcx>, + ) -> bool { + if ty == real_ty { return true; } // This is actually a fast pointer comparison + return match (&ty.sty, &real_ty.sty) { + // Permit changing the pointer type of raw pointers and references as well as + // mutability of raw pointers. + // TODO: Should not be allowed when fat pointers are involved. + (&TypeVariants::TyRawPtr(_), &TypeVariants::TyRawPtr(_)) => true, + (&TypeVariants::TyRef(_, _), &TypeVariants::TyRef(_, _)) => + ty.is_mutable_pointer() == real_ty.is_mutable_pointer(), + // rule out everything else + _ => false + } + } + + if sig.abi == real_sig.abi && + sig.variadic == real_sig.variadic && + sig.inputs_and_output.len() == real_sig.inputs_and_output.len() && + sig.inputs_and_output.iter().zip(real_sig.inputs_and_output).all(|(ty, real_ty)| check_ty_compat(ty, real_ty)) { + // Definitely good. + return Ok(true); + } + + if sig.variadic || real_sig.variadic { + // We're not touching this + return Ok(false); + } + + // We need to allow what comes up when a non-capturing closure is cast to a fn(). + match (sig.abi, real_sig.abi) { + (Abi::Rust, Abi::RustCall) // check the ABIs. This makes the test here non-symmetric. + if check_ty_compat(sig.output(), real_sig.output()) && real_sig.inputs_and_output.len() == 3 => { + // First argument of real_sig must be a ZST + let fst_ty = real_sig.inputs_and_output[0]; + let layout = self.type_layout(fst_ty)?; + let size = layout.size(&self.tcx.data_layout).bytes(); + if size == 0 { + // Second argument must be a tuple matching the argument list of sig + let snd_ty = real_sig.inputs_and_output[1]; + match snd_ty.sty { + TypeVariants::TyTuple(tys, _) if sig.inputs().len() == tys.len() => + if sig.inputs().iter().zip(tys).all(|(ty, real_ty)| check_ty_compat(ty, real_ty)) { + return Ok(true) + }, + _ => {} + } + } + } + _ => {} + }; + + // Nope, this doesn't work. + return Ok(false); + } + fn eval_fn_call( &mut self, instance: ty::Instance<'tcx>, diff --git a/tests/compile-fail/cast_fn_ptr2.rs b/tests/compile-fail/cast_fn_ptr2.rs new file mode 100644 index 000000000000..5d902e1f9aaa --- /dev/null +++ b/tests/compile-fail/cast_fn_ptr2.rs @@ -0,0 +1,9 @@ +fn main() { + fn f(_ : (i32,i32)) {} + + let g = unsafe { + std::mem::transmute::(f) + }; + + g(42) //~ ERROR tried to call a function with sig fn((i32, i32)) through a function pointer of type fn(i32) +} diff --git a/tests/run-pass/cast_fn_ptr.rs b/tests/run-pass/cast_fn_ptr.rs new file mode 100644 index 000000000000..109e8dfc2a02 --- /dev/null +++ b/tests/run-pass/cast_fn_ptr.rs @@ -0,0 +1,9 @@ +fn main() { + fn f(_: *const u8) {} + + let g = unsafe { + std::mem::transmute::(f) + }; + + g(&42 as *const _); +}