From d335ea91b5bc6cf184229c49e313e70f3ab180bd Mon Sep 17 00:00:00 2001 From: sayantn Date: Thu, 9 Oct 2025 17:05:34 +0530 Subject: [PATCH] Refactor implementation of float minmax intrinsics --- .../src/interpret/intrinsics.rs | 135 +++++++++--------- .../src/interpret/intrinsics/simd.rs | 54 ++----- 2 files changed, 81 insertions(+), 108 deletions(-) diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index f0712644465a..90f14a69fac3 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -34,6 +34,26 @@ enum MulAddType { Nondeterministic, } +#[derive(Copy, Clone)] +pub(crate) enum MinMax { + /// The IEEE `Minimum` operation - see `f32::minimum` etc + /// In particular, `-0.0` is considered smaller than `+0.0` and + /// if either input is NaN, the result is NaN. + Minimum, + /// The IEEE `MinNum` operation - see `f32::min` etc + /// In particular, if the inputs are `-0.0` and `+0.0`, the result is non-deterministic, + /// and is one argument is NaN, the other one is returned. + MinNum, + /// The IEEE `Maximum` operation - see `f32::maximum` etc + /// In particular, `-0.0` is considered smaller than `+0.0` and + /// if either input is NaN, the result is NaN. + Maximum, + /// The IEEE `MaxNum` operation - see `f32::max` etc + /// In particular, if the inputs are `-0.0` and `+0.0`, the result is non-deterministic, + /// and is one argument is NaN, the other one is returned. + MaxNum, +} + /// Directly returns an `Allocation` containing an absolute path representation of the given type. pub(crate) fn alloc_type_name<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> (AllocId, u64) { let path = crate::util::type_name(tcx, ty); @@ -513,25 +533,33 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.write_scalar(Scalar::from_target_usize(align.bytes(), self), dest)?; } - sym::minnumf16 => self.float_min_intrinsic::(args, dest)?, - sym::minnumf32 => self.float_min_intrinsic::(args, dest)?, - sym::minnumf64 => self.float_min_intrinsic::(args, dest)?, - sym::minnumf128 => self.float_min_intrinsic::(args, dest)?, + sym::minnumf16 => self.float_minmax_intrinsic::(args, MinMax::MinNum, dest)?, + sym::minnumf32 => self.float_minmax_intrinsic::(args, MinMax::MinNum, dest)?, + sym::minnumf64 => self.float_minmax_intrinsic::(args, MinMax::MinNum, dest)?, + sym::minnumf128 => self.float_minmax_intrinsic::(args, MinMax::MinNum, dest)?, - sym::minimumf16 => self.float_minimum_intrinsic::(args, dest)?, - sym::minimumf32 => self.float_minimum_intrinsic::(args, dest)?, - sym::minimumf64 => self.float_minimum_intrinsic::(args, dest)?, - sym::minimumf128 => self.float_minimum_intrinsic::(args, dest)?, + sym::minimumf16 => self.float_minmax_intrinsic::(args, MinMax::Minimum, dest)?, + sym::minimumf32 => { + self.float_minmax_intrinsic::(args, MinMax::Minimum, dest)? + } + sym::minimumf64 => { + self.float_minmax_intrinsic::(args, MinMax::Minimum, dest)? + } + sym::minimumf128 => self.float_minmax_intrinsic::(args, MinMax::Minimum, dest)?, - sym::maxnumf16 => self.float_max_intrinsic::(args, dest)?, - sym::maxnumf32 => self.float_max_intrinsic::(args, dest)?, - sym::maxnumf64 => self.float_max_intrinsic::(args, dest)?, - sym::maxnumf128 => self.float_max_intrinsic::(args, dest)?, + sym::maxnumf16 => self.float_minmax_intrinsic::(args, MinMax::MaxNum, dest)?, + sym::maxnumf32 => self.float_minmax_intrinsic::(args, MinMax::MaxNum, dest)?, + sym::maxnumf64 => self.float_minmax_intrinsic::(args, MinMax::MaxNum, dest)?, + sym::maxnumf128 => self.float_minmax_intrinsic::(args, MinMax::MaxNum, dest)?, - sym::maximumf16 => self.float_maximum_intrinsic::(args, dest)?, - sym::maximumf32 => self.float_maximum_intrinsic::(args, dest)?, - sym::maximumf64 => self.float_maximum_intrinsic::(args, dest)?, - sym::maximumf128 => self.float_maximum_intrinsic::(args, dest)?, + sym::maximumf16 => self.float_minmax_intrinsic::(args, MinMax::Maximum, dest)?, + sym::maximumf32 => { + self.float_minmax_intrinsic::(args, MinMax::Maximum, dest)? + } + sym::maximumf64 => { + self.float_minmax_intrinsic::(args, MinMax::Maximum, dest)? + } + sym::maximumf128 => self.float_minmax_intrinsic::(args, MinMax::Maximum, dest)?, sym::copysignf16 => self.float_copysign_intrinsic::(args, dest)?, sym::copysignf32 => self.float_copysign_intrinsic::(args, dest)?, @@ -936,76 +964,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { interp_ok(Scalar::from_bool(lhs_bytes == rhs_bytes)) } - fn float_min_intrinsic( - &mut self, - args: &[OpTy<'tcx, M::Provenance>], - dest: &PlaceTy<'tcx, M::Provenance>, - ) -> InterpResult<'tcx, ()> + fn float_minmax( + &self, + a: Scalar, + b: Scalar, + op: MinMax, + ) -> InterpResult<'tcx, Scalar> where F: rustc_apfloat::Float + rustc_apfloat::FloatConvert + Into>, { - let a: F = self.read_scalar(&args[0])?.to_float()?; - let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = if a == b { + let a: F = a.to_float()?; + let b: F = b.to_float()?; + let res = if matches!(op, MinMax::MinNum | MinMax::MaxNum) && a == b { // They are definitely not NaN (those are never equal), but they could be `+0` and `-0`. // Let the machine decide which one to return. M::equal_float_min_max(self, a, b) } else { - self.adjust_nan(a.min(b), &[a, b]) + let result = match op { + MinMax::Minimum => a.minimum(b), + MinMax::MinNum => a.min(b), + MinMax::Maximum => a.maximum(b), + MinMax::MaxNum => a.max(b), + }; + self.adjust_nan(result, &[a, b]) }; - self.write_scalar(res, dest)?; - interp_ok(()) + + interp_ok(res.into()) } - fn float_max_intrinsic( + fn float_minmax_intrinsic( &mut self, args: &[OpTy<'tcx, M::Provenance>], + op: MinMax, dest: &PlaceTy<'tcx, M::Provenance>, ) -> InterpResult<'tcx, ()> where F: rustc_apfloat::Float + rustc_apfloat::FloatConvert + Into>, { - let a: F = self.read_scalar(&args[0])?.to_float()?; - let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = if a == b { - // They are definitely not NaN (those are never equal), but they could be `+0` and `-0`. - // Let the machine decide which one to return. - M::equal_float_min_max(self, a, b) - } else { - self.adjust_nan(a.max(b), &[a, b]) - }; - self.write_scalar(res, dest)?; - interp_ok(()) - } - - fn float_minimum_intrinsic( - &mut self, - args: &[OpTy<'tcx, M::Provenance>], - dest: &PlaceTy<'tcx, M::Provenance>, - ) -> InterpResult<'tcx, ()> - where - F: rustc_apfloat::Float + rustc_apfloat::FloatConvert + Into>, - { - let a: F = self.read_scalar(&args[0])?.to_float()?; - let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = a.minimum(b); - let res = self.adjust_nan(res, &[a, b]); - self.write_scalar(res, dest)?; - interp_ok(()) - } - - fn float_maximum_intrinsic( - &mut self, - args: &[OpTy<'tcx, M::Provenance>], - dest: &PlaceTy<'tcx, M::Provenance>, - ) -> InterpResult<'tcx, ()> - where - F: rustc_apfloat::Float + rustc_apfloat::FloatConvert + Into>, - { - let a: F = self.read_scalar(&args[0])?.to_float()?; - let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = a.maximum(b); - let res = self.adjust_nan(res, &[a, b]); + let res = + self.float_minmax::(self.read_scalar(&args[0])?, self.read_scalar(&args[1])?, op)?; self.write_scalar(res, dest)?; interp_ok(()) } diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs index 84489028e190..13b6623accd2 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs @@ -9,17 +9,11 @@ use rustc_span::{Symbol, sym}; use tracing::trace; use super::{ - ImmTy, InterpCx, InterpResult, Machine, MulAddType, OpTy, PlaceTy, Provenance, Scalar, Size, - interp_ok, throw_ub_format, + ImmTy, InterpCx, InterpResult, Machine, MinMax, MulAddType, OpTy, PlaceTy, Provenance, Scalar, + Size, interp_ok, throw_ub_format, }; use crate::interpret::Writeable; -#[derive(Copy, Clone)] -pub(crate) enum MinMax { - Min, - Max, -} - impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { /// Returns `true` if emulation happened. /// Here we implement the intrinsics that are common to all CTFE instances; individual machines can add their own @@ -217,8 +211,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { sym::simd_le => Op::MirOp(BinOp::Le), sym::simd_gt => Op::MirOp(BinOp::Gt), sym::simd_ge => Op::MirOp(BinOp::Ge), - sym::simd_fmax => Op::FMinMax(MinMax::Max), - sym::simd_fmin => Op::FMinMax(MinMax::Min), + sym::simd_fmax => Op::FMinMax(MinMax::MaxNum), + sym::simd_fmin => Op::FMinMax(MinMax::MinNum), sym::simd_saturating_add => Op::SaturatingOp(BinOp::Add), sym::simd_saturating_sub => Op::SaturatingOp(BinOp::Sub), sym::simd_arith_offset => Op::WrappingOffset, @@ -310,8 +304,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { sym::simd_reduce_xor => Op::MirOp(BinOp::BitXor), sym::simd_reduce_any => Op::MirOpBool(BinOp::BitOr), sym::simd_reduce_all => Op::MirOpBool(BinOp::BitAnd), - sym::simd_reduce_max => Op::MinMax(MinMax::Max), - sym::simd_reduce_min => Op::MinMax(MinMax::Min), + sym::simd_reduce_max => Op::MinMax(MinMax::MaxNum), + sym::simd_reduce_min => Op::MinMax(MinMax::MinNum), _ => unreachable!(), }; @@ -333,10 +327,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { if matches!(res.layout.ty.kind(), ty::Float(_)) { ImmTy::from_scalar(self.fminmax_op(mmop, &res, &op)?, res.layout) } else { - // Just boring integers, so NaNs to worry about + // Just boring integers, no NaNs to worry about. let mirop = match mmop { - MinMax::Min => BinOp::Le, - MinMax::Max => BinOp::Ge, + MinMax::MinNum | MinMax::Minimum => BinOp::Le, + MinMax::MaxNum | MinMax::Maximum => BinOp::Ge, }; if self.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? { res @@ -749,12 +743,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { interp_ok(true) } - fn fminmax_op( + fn fminmax_op( &self, op: MinMax, - left: &ImmTy<'tcx, Prov>, - right: &ImmTy<'tcx, Prov>, - ) -> InterpResult<'tcx, Scalar> { + left: &ImmTy<'tcx, M::Provenance>, + right: &ImmTy<'tcx, M::Provenance>, + ) -> InterpResult<'tcx, Scalar> { assert_eq!(left.layout.ty, right.layout.ty); let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmax operand is not a float") @@ -763,26 +757,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let right = right.to_scalar(); interp_ok(match float_ty { FloatTy::F16 => unimplemented!("f16_f128"), - FloatTy::F32 => { - let left = left.to_f32()?; - let right = right.to_f32()?; - let res = match op { - MinMax::Min => left.min(right), - MinMax::Max => left.max(right), - }; - let res = self.adjust_nan(res, &[left, right]); - Scalar::from_f32(res) - } - FloatTy::F64 => { - let left = left.to_f64()?; - let right = right.to_f64()?; - let res = match op { - MinMax::Min => left.min(right), - MinMax::Max => left.max(right), - }; - let res = self.adjust_nan(res, &[left, right]); - Scalar::from_f64(res) - } + FloatTy::F32 => self.float_minmax::(left, right, op)?, + FloatTy::F64 => self.float_minmax::(left, right, op)?, FloatTy::F128 => unimplemented!("f16_f128"), }) }