Use simpler branchy swap logic in tiny merge sort

Avoids the code duplication issue and results in
smaller binary size, which after all is the
purpose of the feature.
This commit is contained in:
Lukas Bergdoll 2024-08-29 10:32:59 +02:00
parent 7815d77ee0
commit 00eca77b72
2 changed files with 13 additions and 42 deletions

View file

@ -378,6 +378,11 @@ where
/// Swap two values in the slice pointed to by `v_base` at the position `a_pos` and `b_pos` if the
/// value at position `b_pos` is less than the one at position `a_pos`.
///
/// Purposefully not marked `#[inline]`, despite us wanting it to be inlined for integers like
/// types. `is_less` could be a huge function and we want to give the compiler an option to
/// not inline this function. For the same reasons that this function is very perf critical
/// it should be in the same module as the functions that use it.
unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,

View file

@ -1,6 +1,6 @@
//! Binary-size optimized mergesort inspired by https://github.com/voultapher/tiny-sort-rs.
use crate::mem::{ManuallyDrop, MaybeUninit};
use crate::mem::MaybeUninit;
use crate::ptr;
use crate::slice::sort::stable::merge;
@ -27,49 +27,15 @@ pub fn mergesort<T, F: FnMut(&T, &T) -> bool>(
merge::merge(v, scratch, mid, is_less);
} else if len == 2 {
// Branchless swap the two elements. This reduces the recursion depth and improves
// perf significantly at a small binary-size cost. Trades ~10% perf boost for integers
// for ~50 bytes in the binary.
// SAFETY: We checked the len, the pointers we create are valid and don't overlap.
unsafe {
swap_if_less(v.as_mut_ptr(), 0, 1, is_less);
let v_base = v.as_mut_ptr();
let v_a = v_base;
let v_b = v_base.add(1);
if is_less(&*v_b, &*v_a) {
ptr::swap_nonoverlapping(v_a, v_b, 1);
}
}
}
}
/// Swap two values in the slice pointed to by `v_base` at the position `a_pos` and `b_pos` if the
/// value at position `b_pos` is less than the one at position `a_pos`.
unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
unsafe {
let v_a = v_base.add(a_pos);
let v_b = v_base.add(b_pos);
// PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be
// in a well defined state, without duplicates.
// Important to only swap if it is more and not if it is equal. is_less should return false for
// equal, so we don't swap.
let should_swap = is_less(&*v_b, &*v_a);
// This is a branchless version of swap if.
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap(left, right, 1);
// }
// The goal is to generate cmov instructions here.
let left_swap = if should_swap { v_b } else { v_a };
let right_swap = if should_swap { v_a } else { v_b };
let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
ptr::copy(left_swap, v_a, 1);
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
}
}