solve autodiffv2.rs FIXME and make identical_fnc test more robust

This commit is contained in:
Manuel Drehwald
2025-10-05 03:07:51 -04:00
parent 227ac7c3cd
commit 7c8fe29fd6
2 changed files with 15 additions and 20 deletions

View File

@@ -18,11 +18,6 @@
// but each shadow argument is `width` times larger (thus 16 and 20 elements here).
// `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the
// original function arguments.
//
// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they
// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which
// try to rewrite the dummy functions later. We should consider to change to pure declarations both
// in our frontend and in the llvm backend to avoid these issues.
#![feature(autodiff)]
@@ -30,7 +25,7 @@ use std::autodiff::autodiff_forward;
// CHECK: ;
#[no_mangle]
//#[autodiff(d_square1, Forward, Dual, Dual)]
#[autodiff_forward(d_square1, Dual, Dual)]
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
#[autodiff_forward(d_square3, 4, Dual, Dual)]
fn square(x: &[f32], y: &mut [f32]) {
@@ -79,25 +74,25 @@ fn main() {
let mut dy3_4 = std::hint::black_box(vec![0.0; 5]);
// scalar.
//d_square1(&x1, &z1, &mut y1, &mut dy1_1);
//d_square1(&x1, &z2, &mut y2, &mut dy1_2);
//d_square1(&x1, &z3, &mut y3, &mut dy1_3);
//d_square1(&x1, &z4, &mut y4, &mut dy1_4);
d_square1(&x1, &z1, &mut y1, &mut dy1_1);
d_square1(&x1, &z2, &mut y2, &mut dy1_2);
d_square1(&x1, &z3, &mut y3, &mut dy1_3);
d_square1(&x1, &z4, &mut y4, &mut dy1_4);
// assert y1 == y2 == y3 == y4
//for i in 0..5 {
// assert_eq!(y1[i], y2[i]);
// assert_eq!(y1[i], y3[i]);
// assert_eq!(y1[i], y4[i]);
//}
for i in 0..5 {
assert_eq!(y1[i], y2[i]);
assert_eq!(y1[i], y3[i]);
assert_eq!(y1[i], y4[i]);
}
// batch mode A)
d_square2(&x1, &z5, &mut y5, &mut dy2);
// assert y1 == y2 == y3 == y4 == y5
//for i in 0..5 {
// assert_eq!(y1[i], y5[i]);
//}
for i in 0..5 {
assert_eq!(y1[i], y5[i]);
}
// batch mode B)
d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4);

View File

@@ -32,9 +32,9 @@ fn square2(x: &f64) -> f64 {
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square[[HASH:.+]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square[[HASH]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
fn main() {
let x = std::hint::black_box(3.0);