Implement rounding for the hex float parsing and prepare to improve error handling

Parsing errors are now bubbled up part of the way, but that needs some
more work.

Rounding should be correct, and the `Status` returned by `parse_any`
should have the correct bits set. These are used for the current (unchanged)
behavior of the surface level functions like `hf64`: panic on invalid inputs, or
values that aren't exactly representable.
This commit is contained in:
quaternic 2025-04-15 03:46:12 +03:00 committed by GitHub
parent 28b6df8603
commit b955cc691e
3 changed files with 410 additions and 116 deletions

View file

@ -3,8 +3,6 @@
use std::cmp::{self, Ordering};
use std::{fmt, ops};
use libm::support::hex_float::parse_any;
use crate::Float;
/// Sometimes verifying float logic is easiest when all values can quickly be checked exhaustively
@ -499,5 +497,6 @@ impl fmt::LowerHex for f8 {
}
pub const fn hf8(s: &str) -> f8 {
f8(parse_any(s, 8, 3) as u8)
let Ok(bits) = libm::support::hex_float::parse_hex_exact(s, 8, 3) else { panic!() };
f8(bits as u8)
}

View file

@ -46,7 +46,7 @@ pub enum Round {
}
/// IEEE 754 exception status flags.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Status(u8);
impl Status {
@ -90,16 +90,22 @@ impl Status {
/// True if `UNDERFLOW` is set.
#[cfg_attr(not(feature = "unstable-public-internals"), allow(dead_code))]
pub fn underflow(self) -> bool {
pub const fn underflow(self) -> bool {
self.0 & Self::UNDERFLOW.0 != 0
}
/// True if `OVERFLOW` is set.
#[cfg_attr(not(feature = "unstable-public-internals"), allow(dead_code))]
pub const fn overflow(self) -> bool {
self.0 & Self::OVERFLOW.0 != 0
}
pub fn set_underflow(&mut self, val: bool) {
self.set_flag(val, Self::UNDERFLOW);
}
/// True if `INEXACT` is set.
pub fn inexact(self) -> bool {
pub const fn inexact(self) -> bool {
self.0 & Self::INEXACT.0 != 0
}
@ -114,4 +120,8 @@ impl Status {
self.0 &= !mask.0;
}
}
pub(crate) const fn with(self, rhs: Self) -> Self {
Self(self.0 | rhs.0)
}
}

View file

@ -2,149 +2,260 @@
use core::fmt;
use super::{Float, f32_from_bits, f64_from_bits};
use super::{Float, Round, Status, f32_from_bits, f64_from_bits};
/// Construct a 16-bit float from hex float representation (C-style)
#[cfg(f16_enabled)]
pub const fn hf16(s: &str) -> f16 {
f16::from_bits(parse_any(s, 16, 10) as u16)
match parse_hex_exact(s, 16, 10) {
Ok(bits) => f16::from_bits(bits as u16),
Err(HexFloatParseError(s)) => panic!("{}", s),
}
}
/// Construct a 32-bit float from hex float representation (C-style)
#[allow(unused)]
pub const fn hf32(s: &str) -> f32 {
f32_from_bits(parse_any(s, 32, 23) as u32)
match parse_hex_exact(s, 32, 23) {
Ok(bits) => f32_from_bits(bits as u32),
Err(HexFloatParseError(s)) => panic!("{}", s),
}
}
/// Construct a 64-bit float from hex float representation (C-style)
pub const fn hf64(s: &str) -> f64 {
f64_from_bits(parse_any(s, 64, 52) as u64)
match parse_hex_exact(s, 64, 52) {
Ok(bits) => f64_from_bits(bits as u64),
Err(HexFloatParseError(s)) => panic!("{}", s),
}
}
/// Construct a 128-bit float from hex float representation (C-style)
#[cfg(f128_enabled)]
pub const fn hf128(s: &str) -> f128 {
f128::from_bits(parse_any(s, 128, 112))
match parse_hex_exact(s, 128, 112) {
Ok(bits) => f128::from_bits(bits),
Err(HexFloatParseError(s)) => panic!("{}", s),
}
}
#[derive(Copy, Clone, Debug)]
pub struct HexFloatParseError(&'static str);
/// Parses any float to its bitwise representation, returning an error if it cannot be represented exactly
pub const fn parse_hex_exact(
s: &str,
bits: u32,
sig_bits: u32,
) -> Result<u128, HexFloatParseError> {
match parse_any(s, bits, sig_bits, Round::Nearest) {
Err(e) => Err(e),
Ok((bits, Status::OK)) => Ok(bits),
Ok((_, status)) if status.overflow() => Err(HexFloatParseError("the value is too huge")),
Ok((_, status)) if status.underflow() => Err(HexFloatParseError("the value is too tiny")),
Ok((_, status)) if status.inexact() => Err(HexFloatParseError("the value is too precise")),
Ok(_) => unreachable!(),
}
}
/// Parse any float from hex to its bitwise representation.
///
/// `nan_repr` is passed rather than constructed so the platform-specific NaN is returned.
pub const fn parse_any(s: &str, bits: u32, sig_bits: u32) -> u128 {
pub const fn parse_any(
s: &str,
bits: u32,
sig_bits: u32,
round: Round,
) -> Result<(u128, Status), HexFloatParseError> {
let mut b = s.as_bytes();
if sig_bits > 119 || bits > 128 || bits < sig_bits + 3 || bits > sig_bits + 30 {
return Err(HexFloatParseError("unsupported target float configuration"));
}
let neg = matches!(b, [b'-', ..]);
if let &[b'-' | b'+', ref rest @ ..] = b {
b = rest;
}
let sign_bit = 1 << (bits - 1);
let quiet_bit = 1 << (sig_bits - 1);
let nan = sign_bit - quiet_bit;
let inf = nan - quiet_bit;
let (mut x, status) = match *b {
[b'i' | b'I', b'n' | b'N', b'f' | b'F'] => (inf, Status::OK),
[b'n' | b'N', b'a' | b'A', b'n' | b'N'] => (nan, Status::OK),
[b'0', b'x' | b'X', ref rest @ ..] => {
let round = match (neg, round) {
// parse("-x", Round::Positive) == -parse("x", Round::Negative)
(true, Round::Positive) => Round::Negative,
(true, Round::Negative) => Round::Positive,
// rounding toward nearest or zero are symmetric
(true, Round::Nearest | Round::Zero) | (false, _) => round,
};
match parse_finite(rest, bits, sig_bits, round) {
Err(e) => return Err(e),
Ok(res) => res,
}
}
_ => return Err(HexFloatParseError("no hex indicator")),
};
if neg {
x ^= sign_bit;
}
Ok((x, status))
}
const fn parse_finite(
b: &[u8],
bits: u32,
sig_bits: u32,
rounding_mode: Round,
) -> Result<(u128, Status), HexFloatParseError> {
let exp_bits: u32 = bits - sig_bits - 1;
let max_msb: i32 = (1 << (exp_bits - 1)) - 1;
// The exponent of one ULP in the subnormals
let min_lsb: i32 = 1 - max_msb - sig_bits as i32;
let exp_mask = ((1 << exp_bits) - 1) << sig_bits;
let (neg, mut sig, exp) = match parse_hex(s.as_bytes()) {
Parsed::Finite { neg, sig: 0, .. } => return (neg as u128) << (bits - 1),
Parsed::Finite { neg, sig, exp } => (neg, sig, exp),
Parsed::Infinite { neg } => return ((neg as u128) << (bits - 1)) | exp_mask,
Parsed::Nan { neg } => {
return ((neg as u128) << (bits - 1)) | exp_mask | (1 << (sig_bits - 1));
}
let (mut sig, mut exp) = match parse_hex(b) {
Err(e) => return Err(e),
Ok(Parsed { sig: 0, .. }) => return Ok((0, Status::OK)),
Ok(Parsed { sig, exp }) => (sig, exp),
};
// exponents of the least and most significant bits in the value
let lsb = sig.trailing_zeros() as i32;
let msb = u128_ilog2(sig) as i32;
let sig_bits = sig_bits as i32;
let mut round_bits = u128_ilog2(sig) as i32 - sig_bits as i32;
assert!(msb - lsb <= sig_bits, "the value is too precise");
assert!(msb + exp <= max_msb, "the value is too huge");
assert!(lsb + exp >= min_lsb, "the value is too tiny");
// Round at least up to min_lsb
if exp < min_lsb - round_bits {
round_bits = min_lsb - exp;
}
let mut status = Status::OK;
exp += round_bits;
if round_bits > 0 {
// first, prepare for rounding exactly two bits
if round_bits == 1 {
sig <<= 1;
} else if round_bits > 2 {
sig = shr_odd_rounding(sig, (round_bits - 2) as u32);
}
if sig & 0b11 != 0 {
status = Status::INEXACT;
}
sig = shr2_round(sig, rounding_mode);
} else if round_bits < 0 {
sig <<= -round_bits;
}
// The parsed value is X = sig * 2^exp
// Expressed as a multiple U of the smallest subnormal value:
// X = U * 2^min_lsb, so U = sig * 2^(exp-min_lsb)
let mut uexp = exp - min_lsb;
let uexp = (exp - min_lsb) as u128;
let uexp = uexp << sig_bits;
let shift = if uexp + msb >= sig_bits {
// normal, shift msb to position sig_bits
sig_bits - msb
} else {
// subnormal, shift so that uexp becomes 0
uexp
// Note that it is possible for the exponent bits to equal 2 here
// if the value rounded up, but that means the mantissa is all zeroes
// so the value is still correct
debug_assert!(sig <= 2 << sig_bits);
let inf = ((1 << exp_bits) - 1) << sig_bits;
let bits = match sig.checked_add(uexp) {
Some(bits) if bits < inf => {
// inexact subnormal or zero?
if status.inexact() && bits < (1 << sig_bits) {
status = status.with(Status::UNDERFLOW);
}
bits
}
_ => {
// overflow to infinity
status = status.with(Status::OVERFLOW).with(Status::INEXACT);
match rounding_mode {
Round::Positive | Round::Nearest => inf,
Round::Negative | Round::Zero => inf - 1,
}
}
};
if shift >= 0 {
sig <<= shift;
} else {
sig >>= -shift;
}
uexp -= shift;
// the most significant bit is like having 1 in the exponent bits
// add any leftover exponent to that
assert!(uexp >= 0 && uexp < (1 << exp_bits) - 2);
sig += (uexp as u128) << sig_bits;
// finally, set the sign bit if necessary
sig | ((neg as u128) << (bits - 1))
Ok((bits, status))
}
/// A parsed floating point number.
enum Parsed {
/// Absolute value sig * 2^e
Finite {
neg: bool,
sig: u128,
exp: i32,
},
Infinite {
neg: bool,
},
Nan {
neg: bool,
},
/// Shift right, rounding all inexact divisions to the nearest odd number
/// E.g. (0 >> 4) -> 0, (1..=31 >> 4) -> 1, (32 >> 4) -> 2, ...
///
/// Useful for reducing a number before rounding the last two bits, since
/// the result of the final rounding is preserved for all rounding modes.
const fn shr_odd_rounding(x: u128, k: u32) -> u128 {
if k < 128 {
let inexact = x.trailing_zeros() < k;
(x >> k) | (inexact as u128)
} else {
(x != 0) as u128
}
}
/// Divide by 4, rounding with the given mode
const fn shr2_round(mut x: u128, round: Round) -> u128 {
let t = (x as u32) & 0b111;
x >>= 2;
match round {
// Look-up-table on the last three bits for when to round up
Round::Nearest => x + ((0b11001000_u8 >> t) & 1) as u128,
Round::Negative => x,
Round::Zero => x,
Round::Positive => x + (t & 0b11 != 0) as u128,
}
}
/// A parsed finite and unsigned floating point number.
struct Parsed {
/// Absolute value sig * 2^exp
sig: u128,
exp: i32,
}
/// Parse a hexadecimal float x
const fn parse_hex(mut b: &[u8]) -> Parsed {
let mut neg = false;
const fn parse_hex(mut b: &[u8]) -> Result<Parsed, HexFloatParseError> {
let mut sig: u128 = 0;
let mut exp: i32 = 0;
if let &[c @ (b'-' | b'+'), ref rest @ ..] = b {
b = rest;
neg = c == b'-';
}
match *b {
[b'i' | b'I', b'n' | b'N', b'f' | b'F'] => return Parsed::Infinite { neg },
[b'n' | b'N', b'a' | b'A', b'n' | b'N'] => return Parsed::Nan { neg },
_ => (),
}
if let &[b'0', b'x' | b'X', ref rest @ ..] = b {
b = rest;
} else {
panic!("no hex indicator");
}
let mut seen_point = false;
let mut some_digits = false;
let mut inexact = false;
while let &[c, ref rest @ ..] = b {
b = rest;
match c {
b'.' => {
assert!(!seen_point);
if seen_point {
return Err(HexFloatParseError("unexpected '.' parsing fractional digits"));
}
seen_point = true;
continue;
}
b'p' | b'P' => break,
c => {
let digit = hex_digit(c);
let digit = match hex_digit(c) {
Some(d) => d,
None => return Err(HexFloatParseError("expected hexadecimal digit")),
};
some_digits = true;
let of;
(sig, of) = sig.overflowing_mul(16);
assert!(!of, "too many digits");
sig |= digit as u128;
// up until the fractional point, the value grows
if (sig >> 124) == 0 {
sig <<= 4;
sig |= digit as u128;
} else {
// FIXME: it is technically possible for exp to overflow if parsing a string with >500M digits
exp += 4;
inexact |= digit != 0;
}
// Up until the fractional point, the value grows
// with more digits, but after it the exponent is
// compensated to match.
if seen_point {
@ -153,49 +264,79 @@ const fn parse_hex(mut b: &[u8]) -> Parsed {
}
}
}
assert!(some_digits, "at least one digit is required");
// If we've set inexact, the exact value has more than 125
// significant bits, and lies somewhere between sig and sig + 1.
// Because we'll round off at least two of the trailing bits,
// setting the last bit gives correct rounding for inexact values.
sig |= inexact as u128;
if !some_digits {
return Err(HexFloatParseError("at least one digit is required"));
};
some_digits = false;
let mut negate_exp = false;
if let &[c @ (b'-' | b'+'), ref rest @ ..] = b {
let negate_exp = matches!(b, [b'-', ..]);
if let &[b'-' | b'+', ref rest @ ..] = b {
b = rest;
negate_exp = c == b'-';
}
let mut pexp: i32 = 0;
let mut pexp: u32 = 0;
while let &[c, ref rest @ ..] = b {
b = rest;
let digit = dec_digit(c);
let digit = match dec_digit(c) {
Some(d) => d,
None => return Err(HexFloatParseError("expected decimal digit")),
};
some_digits = true;
let of;
(pexp, of) = pexp.overflowing_mul(10);
assert!(!of, "too many exponent digits");
pexp += digit as i32;
pexp = pexp.saturating_mul(10);
pexp += digit as u32;
}
assert!(some_digits, "at least one exponent digit is required");
if !some_digits {
return Err(HexFloatParseError("at least one exponent digit is required"));
};
{
let e;
if negate_exp {
e = (exp as i64) - (pexp as i64);
} else {
e = (exp as i64) + (pexp as i64);
};
exp = if e < i32::MIN as i64 {
i32::MIN
} else if e > i32::MAX as i64 {
i32::MAX
} else {
e as i32
};
}
/* FIXME(msrv): once MSRV >= 1.66, replace the above workaround block with:
if negate_exp {
exp -= pexp;
exp = exp.saturating_sub_unsigned(pexp);
} else {
exp += pexp;
}
exp = exp.saturating_add_unsigned(pexp);
};
*/
Parsed::Finite { neg, sig, exp }
Ok(Parsed { sig, exp })
}
const fn dec_digit(c: u8) -> u8 {
const fn dec_digit(c: u8) -> Option<u8> {
match c {
b'0'..=b'9' => c - b'0',
_ => panic!("bad char"),
b'0'..=b'9' => Some(c - b'0'),
_ => None,
}
}
const fn hex_digit(c: u8) -> u8 {
const fn hex_digit(c: u8) -> Option<u8> {
match c {
b'0'..=b'9' => c - b'0',
b'a'..=b'f' => c - b'a' + 10,
b'A'..=b'F' => c - b'A' + 10,
_ => panic!("bad char"),
b'0'..=b'9' => Some(c - b'0'),
b'a'..=b'f' => Some(c - b'a' + 10),
b'A'..=b'F' => Some(c - b'A' + 10),
_ => None,
}
}
@ -341,6 +482,61 @@ mod parse_tests {
use super::*;
#[cfg(f16_enabled)]
fn rounding_properties(s: &str) -> Result<(), HexFloatParseError> {
let (xd, s0) = parse_any(s, 16, 10, Round::Negative)?;
let (xu, s1) = parse_any(s, 16, 10, Round::Positive)?;
let (xz, s2) = parse_any(s, 16, 10, Round::Zero)?;
let (xn, s3) = parse_any(s, 16, 10, Round::Nearest)?;
// FIXME: A value between the least normal and largest subnormal
// could have underflow status depend on rounding mode.
if let Status::OK = s0 {
// an exact result is the same for all rounding modes
assert_eq!(s0, s1);
assert_eq!(s0, s2);
assert_eq!(s0, s3);
assert_eq!(xd, xu);
assert_eq!(xd, xz);
assert_eq!(xd, xn);
} else {
assert!([s0, s1, s2, s3].into_iter().all(Status::inexact));
let xd = f16::from_bits(xd as u16);
let xu = f16::from_bits(xu as u16);
let xz = f16::from_bits(xz as u16);
let xn = f16::from_bits(xn as u16);
assert_biteq!(xd.next_up(), xu, "s={s}, xd={xd:?}, xu={xu:?}");
let signs = [xd, xu, xz, xn].map(f16::is_sign_negative);
if signs == [true; 4] {
assert_biteq!(xz, xu);
} else {
assert_eq!(signs, [false; 4]);
assert_biteq!(xz, xd);
}
if xn.to_bits() != xd.to_bits() {
assert_biteq!(xn, xu);
}
}
Ok(())
}
#[test]
#[cfg(f16_enabled)]
fn test_rounding() {
let n = 1_i32 << 14;
for i in -n..n {
let u = i.rotate_right(11) as u32;
let s = format!("{}", Hexf(f32::from_bits(u)));
assert!(rounding_properties(&s).is_ok());
}
}
#[test]
fn test_parse_any() {
for k in -149..=127 {
@ -397,6 +593,48 @@ mod parse_tests {
}
}
// FIXME: this test is causing failures that are likely UB on various platforms
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
#[test]
#[cfg(f128_enabled)]
fn rounding() {
let pi = std::f128::consts::PI;
let s = format!("{}", Hexf(pi));
for k in 0..=111 {
let (bits, status) = parse_any(&s, 128 - k, 112 - k, Round::Nearest).unwrap();
let scale = (1u128 << (112 - k - 1)) as f128;
let expected = (pi * scale).round_ties_even() / scale;
assert_eq!(bits << k, expected.to_bits(), "k = {k}, s = {s}");
assert_eq!(expected != pi, status.inexact());
}
}
#[test]
fn rounding_extreme_underflow() {
for k in 1..1000 {
let s = format!("0x1p{}", -149 - k);
let Ok((bits, status)) = parse_any(&s, 32, 23, Round::Nearest) else { unreachable!() };
assert_eq!(bits, 0, "{s} should round to zero, got bits={bits}");
assert!(status.underflow(), "should indicate underflow when parsing {s}");
assert!(status.inexact(), "should indicate inexact when parsing {s}");
}
}
#[test]
fn long_tail() {
for k in 1..1000 {
let s = format!("0x1.{}p0", "0".repeat(k));
let Ok(bits) = parse_hex_exact(&s, 32, 23) else { panic!("parsing {s} failed") };
assert_eq!(f32::from_bits(bits as u32), 1.0);
let s = format!("0x1.{}1p0", "0".repeat(k));
let Ok((bits, status)) = parse_any(&s, 32, 23, Round::Nearest) else { unreachable!() };
if status.inexact() {
assert!(1.0 == f32::from_bits(bits as u32));
} else {
assert!(1.0 < f32::from_bits(bits as u32));
}
}
}
// HACK(msrv): 1.63 rejects unknown width float literals at an AST level, so use a macro to
// hide them from the AST.
#[cfg(f16_enabled)]
@ -434,6 +672,7 @@ mod parse_tests {
];
for (s, exp) in checks {
println!("parsing {s}");
assert!(rounding_properties(s).is_ok());
let act = hf16(s).to_bits();
assert_eq!(
act, exp,
@ -749,7 +988,13 @@ mod tests_panicking {
#[test]
#[should_panic(expected = "the value is too precise")]
fn test_f128_extra_precision() {
// One bit more than the above.
// Just below the maximum finite.
hf128("0x1.fffffffffffffffffffffffffffe8p+16383");
}
#[test]
#[should_panic(expected = "the value is too huge")]
fn test_f128_extra_precision_overflow() {
// One bit more than the above. Should overflow.
hf128("0x1.ffffffffffffffffffffffffffff8p+16383");
}
@ -822,6 +1067,46 @@ mod print_tests {
}
}
#[test]
#[cfg(f16_enabled)]
fn test_f16_to_f32() {
use std::format;
// Exhaustively check that these are equivalent for all `f16`:
// - `f16 -> f32`
// - `f16 -> str -> f32`
// - `f16 -> f32 -> str -> f32`
// - `f16 -> f32 -> str -> f16 -> f32`
for x in 0..=u16::MAX {
let f16 = f16::from_bits(x);
let s16 = format!("{}", Hexf(f16));
let f32 = f16 as f32;
let s32 = format!("{}", Hexf(f32));
let a = hf32(&s16);
let b = hf32(&s32);
let c = hf16(&s32);
if f32.is_nan() && a.is_nan() && b.is_nan() && c.is_nan() {
continue;
}
assert_eq!(
f32.to_bits(),
a.to_bits(),
"{f16:?} : f16 formatted as {s16} which parsed as {a:?} : f16"
);
assert_eq!(
f32.to_bits(),
b.to_bits(),
"{f32:?} : f32 formatted as {s32} which parsed as {b:?} : f32"
);
assert_eq!(
f32.to_bits(),
(c as f32).to_bits(),
"{f32:?} : f32 formatted as {s32} which parsed as {c:?} : f16"
);
}
}
#[test]
fn spot_checks() {
assert_eq!(Hexf(f32::MAX).to_string(), "0x1.fffffep+127");