Provide a way to override iteration count

Benchmarks need a way to limit how many iterations get run. Introuce a
way to inject this information here.
This commit is contained in:
Trevor Gross 2025-01-16 07:30:51 +00:00
parent b39c2e4147
commit f56b41dbbd

View file

@ -40,6 +40,8 @@ pub struct CheckCtx {
/// Source of truth for tests.
pub basis: CheckBasis,
pub gen_kind: GeneratorKind,
/// If specified, this value will override the value returned by [`iteration_count`].
pub override_iterations: Option<u64>,
}
impl CheckCtx {
@ -53,6 +55,7 @@ impl CheckCtx {
base_name_str: fn_ident.base_name().as_str(),
basis,
gen_kind,
override_iterations: None,
};
ret.ulp = crate::default_ulp(&ret);
ret
@ -62,6 +65,10 @@ impl CheckCtx {
pub fn input_count(&self) -> usize {
self.fn_ident.math_op().rust_sig.args.len()
}
pub fn override_iterations(&mut self, count: u64) {
self.override_iterations = Some(count)
}
}
/// Possible items to test against
@ -71,6 +78,8 @@ pub enum CheckBasis {
Musl,
/// Check against infinite precision (MPFR).
Mpfr,
/// Benchmarks or other times when this is not relevant.
None,
}
/// The different kinds of generators that provide test input, which account for input pattern
@ -216,6 +225,12 @@ pub fn iteration_count(ctx: &CheckCtx, argnum: usize) -> u64 {
total_iterations = 800;
}
let mut overridden = false;
if let Some(count) = ctx.override_iterations {
total_iterations = count;
overridden = true;
}
// Adjust for the number of inputs
let ntests = match t_env.input_count {
1 => total_iterations,
@ -223,6 +238,7 @@ pub fn iteration_count(ctx: &CheckCtx, argnum: usize) -> u64 {
3 => (total_iterations as f64).cbrt().ceil() as u64,
_ => panic!("test has more than three arguments"),
};
let total = ntests.pow(t_env.input_count.try_into().unwrap());
let seed_msg = match ctx.gen_kind {
@ -235,12 +251,13 @@ pub fn iteration_count(ctx: &CheckCtx, argnum: usize) -> u64 {
test_log(&format!(
"{gen_kind:?} {basis:?} {fn_ident} arg {arg}/{args}: {ntests} iterations \
({total} total){seed_msg}",
({total} total){seed_msg}{omsg}",
gen_kind = ctx.gen_kind,
basis = ctx.basis,
fn_ident = ctx.fn_ident,
arg = argnum + 1,
args = t_env.input_count,
omsg = if overridden { " (overridden)" } else { "" }
));
ntests