Add Shr to u256

Float division requires some shift operations on big integers; implement
right shift here.
This commit is contained in:
Trevor Gross 2024-08-19 15:33:05 -05:00
parent cad966f041
commit 44840a12bc
2 changed files with 109 additions and 1 deletions

View file

@ -93,7 +93,7 @@ macro_rules! impl_common {
type Output = Self;
fn shl(self, rhs: u32) -> Self::Output {
todo!()
unimplemented!("only used to meet trait bounds")
}
}
};
@ -102,6 +102,41 @@ macro_rules! impl_common {
impl_common!(i256);
impl_common!(u256);
impl ops::Shr<u32> for u256 {
type Output = Self;
fn shr(self, rhs: u32) -> Self::Output {
assert!(rhs < Self::BITS, "attempted to shift right with overflow");
if rhs == 0 {
return self;
}
let mut ret = self;
let byte_shift = rhs / 64;
let bit_shift = rhs % 64;
for idx in 0..4 {
let base_idx = idx + byte_shift as usize;
let Some(base) = ret.0.get(base_idx) else {
ret.0[idx] = 0;
continue;
};
let mut new_val = base >> bit_shift;
if let Some(new) = ret.0.get(base_idx + 1) {
new_val |= new.overflowing_shl(64 - bit_shift).0;
}
ret.0[idx] = new_val;
}
ret
}
}
macro_rules! word {
(1, $val:expr) => {
(($val >> (32 * 3)) & Self::from(WORD_LO_MASK)) as u64

View file

@ -59,3 +59,76 @@ fn widen_mul_u128() {
}
assert!(errors.is_empty());
}
#[test]
fn not_u128() {
assert_eq!(!u256::ZERO, u256::MAX);
}
#[test]
fn shr_u128() {
let only_low = [
1,
u16::MAX.into(),
u32::MAX.into(),
u64::MAX.into(),
u128::MAX,
];
let mut errors = Vec::new();
for a in only_low {
for perturb in 0..10 {
let a = a.saturating_add(perturb);
for shift in 0..128 {
let res = a.widen() >> shift;
let expected = (a >> shift).widen();
if res != expected {
errors.push((a.widen(), shift, res, expected));
}
}
}
}
let check = [
(
u256::MAX,
1,
u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 1]),
),
(
u256::MAX,
5,
u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 5]),
),
(u256::MAX, 63, u256([u64::MAX, u64::MAX, u64::MAX, 1])),
(u256::MAX, 64, u256([u64::MAX, u64::MAX, u64::MAX, 0])),
(u256::MAX, 65, u256([u64::MAX, u64::MAX, u64::MAX >> 1, 0])),
(u256::MAX, 127, u256([u64::MAX, u64::MAX, 1, 0])),
(u256::MAX, 128, u256([u64::MAX, u64::MAX, 0, 0])),
(u256::MAX, 129, u256([u64::MAX, u64::MAX >> 1, 0, 0])),
(u256::MAX, 191, u256([u64::MAX, 1, 0, 0])),
(u256::MAX, 192, u256([u64::MAX, 0, 0, 0])),
(u256::MAX, 193, u256([u64::MAX >> 1, 0, 0, 0])),
(u256::MAX, 191, u256([u64::MAX, 1, 0, 0])),
(u256::MAX, 254, u256([0b11, 0, 0, 0])),
(u256::MAX, 255, u256([1, 0, 0, 0])),
];
for (input, shift, expected) in check {
let res = input >> shift;
if res != expected {
errors.push((input, shift, res, expected));
}
}
for (a, b, res, expected) in &errors {
eprintln!(
"FAILURE: {} >> {b} = {} got {}",
hexu(*a),
hexu(*expected),
hexu(*res),
);
}
assert!(errors.is_empty());
}