Shorten the autodiff batching test, to make it more reliable
This commit is contained in:
parent
7bcc8a7053
commit
b6d567c12c
1 changed files with 16 additions and 67 deletions
|
|
@ -1,13 +1,11 @@
|
|||
//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
//
|
||||
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
|
||||
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
|
||||
// through LLVM's O3 pipeline, which will remove most of the noise.
|
||||
// However, our integration test could also be affected by changes in how rustc lowers MIR into
|
||||
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
|
||||
// reduce this test to only match the first lines and the ret instructions.
|
||||
|
||||
// This test combines two features of Enzyme, automatic differentiation and batching. As such, it is
|
||||
// especially prone to breakages. I reduced it therefore to a minimal check matches argument/return
|
||||
// types. Based on the original batching author, implementing the batching feature over MLIR instead
|
||||
// of LLVM should give significantly more reliable performance.
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
|
|
@ -22,69 +20,20 @@ fn square(x: &f32) -> f32 {
|
|||
x * x
|
||||
}
|
||||
|
||||
// The base ("scalar") case d_square3, without batching.
|
||||
// CHECK: define internal fastcc float @fwddiffesquare(float %x.0.val, float %"x'.0.val")
|
||||
// CHECK: %0 = fadd fast float %"x'.0.val", %"x'.0.val"
|
||||
// CHECK-NEXT: %1 = fmul fast float %0, %x.0.val
|
||||
// CHECK-NEXT: ret float %1
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// d_square2
|
||||
// CHECK: define internal [4 x float] @fwddiffe4square(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
|
||||
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
|
||||
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
|
||||
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
|
||||
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
|
||||
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
|
||||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %_2 = load float, ptr %x, align 4
|
||||
// CHECK-NEXT: %4 = fmul fast float %"_2'ipl", %_2
|
||||
// CHECK-NEXT: %5 = fmul fast float %"_2'ipl1", %_2
|
||||
// CHECK-NEXT: %6 = fmul fast float %"_2'ipl2", %_2
|
||||
// CHECK-NEXT: %7 = fmul fast float %"_2'ipl3", %_2
|
||||
// CHECK-NEXT: %8 = fmul fast float %"_2'ipl", %_2
|
||||
// CHECK-NEXT: %9 = fmul fast float %"_2'ipl1", %_2
|
||||
// CHECK-NEXT: %10 = fmul fast float %"_2'ipl2", %_2
|
||||
// CHECK-NEXT: %11 = fmul fast float %"_2'ipl3", %_2
|
||||
// CHECK-NEXT: %12 = fadd fast float %4, %8
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = fadd fast float %5, %9
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = fadd fast float %6, %10
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = fadd fast float %7, %11
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: ret [4 x float] %19
|
||||
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
|
||||
// CHECK: ret [4 x float]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// d_square3, the extra float is the original return value (x * x)
|
||||
// CHECK: define internal { float, [4 x float] } @fwddiffe4square.1(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
|
||||
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
|
||||
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
|
||||
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
|
||||
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
|
||||
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
|
||||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %_2 = load float, ptr %x, align 4
|
||||
// CHECK-NEXT: %_0 = fmul float %_2, %_2
|
||||
// CHECK-NEXT: %4 = fmul fast float %"_2'ipl", %_2
|
||||
// CHECK-NEXT: %5 = fmul fast float %"_2'ipl1", %_2
|
||||
// CHECK-NEXT: %6 = fmul fast float %"_2'ipl2", %_2
|
||||
// CHECK-NEXT: %7 = fmul fast float %"_2'ipl3", %_2
|
||||
// CHECK-NEXT: %8 = fmul fast float %"_2'ipl", %_2
|
||||
// CHECK-NEXT: %9 = fmul fast float %"_2'ipl1", %_2
|
||||
// CHECK-NEXT: %10 = fmul fast float %"_2'ipl2", %_2
|
||||
// CHECK-NEXT: %11 = fmul fast float %"_2'ipl3", %_2
|
||||
// CHECK-NEXT: %12 = fadd fast float %4, %8
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = fadd fast float %5, %9
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = fadd fast float %6, %10
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = fadd fast float %7, %11
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
|
||||
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
|
||||
// CHECK-NEXT: ret { float, [4 x float] } %21
|
||||
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.{{.*}}(float %x.0.val, [4 x ptr] %"x'")
|
||||
// CHECK: ret { float, [4 x float] }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
fn main() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue