From 44840a12bcdbdef7eabd82a6dd5b396fdf7d90ed Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Mon, 19 Aug 2024 15:33:05 -0500 Subject: [PATCH] Add `Shr` to `u256` Float division requires some shift operations on big integers; implement right shift here. --- library/compiler-builtins/src/int/big.rs | 37 +++++++++- .../compiler-builtins/testcrate/tests/big.rs | 73 +++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/library/compiler-builtins/src/int/big.rs b/library/compiler-builtins/src/int/big.rs index 019dd728b5d9..e565da897203 100644 --- a/library/compiler-builtins/src/int/big.rs +++ b/library/compiler-builtins/src/int/big.rs @@ -93,7 +93,7 @@ macro_rules! impl_common { type Output = Self; fn shl(self, rhs: u32) -> Self::Output { - todo!() + unimplemented!("only used to meet trait bounds") } } }; @@ -102,6 +102,41 @@ macro_rules! impl_common { impl_common!(i256); impl_common!(u256); +impl ops::Shr for u256 { + type Output = Self; + + fn shr(self, rhs: u32) -> Self::Output { + assert!(rhs < Self::BITS, "attempted to shift right with overflow"); + + if rhs == 0 { + return self; + } + + let mut ret = self; + let byte_shift = rhs / 64; + let bit_shift = rhs % 64; + + for idx in 0..4 { + let base_idx = idx + byte_shift as usize; + + let Some(base) = ret.0.get(base_idx) else { + ret.0[idx] = 0; + continue; + }; + + let mut new_val = base >> bit_shift; + + if let Some(new) = ret.0.get(base_idx + 1) { + new_val |= new.overflowing_shl(64 - bit_shift).0; + } + + ret.0[idx] = new_val; + } + + ret + } +} + macro_rules! word { (1, $val:expr) => { (($val >> (32 * 3)) & Self::from(WORD_LO_MASK)) as u64 diff --git a/library/compiler-builtins/testcrate/tests/big.rs b/library/compiler-builtins/testcrate/tests/big.rs index 128b5ddfd6d5..595f62256079 100644 --- a/library/compiler-builtins/testcrate/tests/big.rs +++ b/library/compiler-builtins/testcrate/tests/big.rs @@ -59,3 +59,76 @@ fn widen_mul_u128() { } assert!(errors.is_empty()); } + +#[test] +fn not_u128() { + assert_eq!(!u256::ZERO, u256::MAX); +} + +#[test] +fn shr_u128() { + let only_low = [ + 1, + u16::MAX.into(), + u32::MAX.into(), + u64::MAX.into(), + u128::MAX, + ]; + + let mut errors = Vec::new(); + + for a in only_low { + for perturb in 0..10 { + let a = a.saturating_add(perturb); + for shift in 0..128 { + let res = a.widen() >> shift; + let expected = (a >> shift).widen(); + if res != expected { + errors.push((a.widen(), shift, res, expected)); + } + } + } + } + + let check = [ + ( + u256::MAX, + 1, + u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 1]), + ), + ( + u256::MAX, + 5, + u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 5]), + ), + (u256::MAX, 63, u256([u64::MAX, u64::MAX, u64::MAX, 1])), + (u256::MAX, 64, u256([u64::MAX, u64::MAX, u64::MAX, 0])), + (u256::MAX, 65, u256([u64::MAX, u64::MAX, u64::MAX >> 1, 0])), + (u256::MAX, 127, u256([u64::MAX, u64::MAX, 1, 0])), + (u256::MAX, 128, u256([u64::MAX, u64::MAX, 0, 0])), + (u256::MAX, 129, u256([u64::MAX, u64::MAX >> 1, 0, 0])), + (u256::MAX, 191, u256([u64::MAX, 1, 0, 0])), + (u256::MAX, 192, u256([u64::MAX, 0, 0, 0])), + (u256::MAX, 193, u256([u64::MAX >> 1, 0, 0, 0])), + (u256::MAX, 191, u256([u64::MAX, 1, 0, 0])), + (u256::MAX, 254, u256([0b11, 0, 0, 0])), + (u256::MAX, 255, u256([1, 0, 0, 0])), + ]; + + for (input, shift, expected) in check { + let res = input >> shift; + if res != expected { + errors.push((input, shift, res, expected)); + } + } + + for (a, b, res, expected) in &errors { + eprintln!( + "FAILURE: {} >> {b} = {} got {}", + hexu(*a), + hexu(*expected), + hexu(*res), + ); + } + assert!(errors.is_empty()); +}