added typetree support for memcpy

This commit is contained in:
Karan Janthe
2025-08-23 23:10:48 +00:00
parent 5d3ebc3804
commit 664e83b3e7
21 changed files with 135 additions and 34 deletions

View File

@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
_src_align: Align,
size: RValue<'gcc>,
flags: MemFlags,
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_size_t(), false);

View File

@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);
bx.lifetime_end(scratch, scratch_size);

View File

@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}

View File

@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};
use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

View File

@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
// FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
DT_Unknown = 6,
}

View File

@@ -1,8 +1,11 @@
use std::ffi::{CString, c_char, c_uint};
use rustc_ast::expand::typetree::FncTree;
#[cfg(llvm_enzyme)]
use {
crate::attributes,
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
std::ffi::{CString, c_char, c_uint},
};
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
use crate::attributes;
use crate::llvm::{self, Value};
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
enzyme_tt
}
#[cfg(not(llvm_enzyme))]
fn to_enzyme_typetree(
_rust_typetree: RustTypeTree,
_data_layout: &str,
_llcx: &llvm::Context,
) -> ! {
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
}
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
#[cfg(llvm_enzyme)]
pub(crate) fn add_tt<'ll>(

View File

@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
src_align,
bx.const_u32(layout.layout.size().bytes() as u32),
MemFlags::empty(),
None,
);
tmp
} else {

View File

@@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
// ...and then load it with the ABI type.
llval = load_cast(bx, cast, llscratch, scratch_align);

View File

@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
if allow_overlap {
bx.memmove(dst, align, src, align, size, flags);
} else {
bx.memcpy(dst, align, src, align, size, flags);
bx.memcpy(dst, align, src, align, size, flags, None);
}
}

View File

@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let align = pointee_layout.align;
let dst = dst_val.immediate();
let src = src_val.immediate();
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
}
mir::StatementKind::FakeRead(..)
| mir::StatementKind::Retag { .. }

View File

@@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
src_align: Align,
size: Self::Value,
flags: MemFlags,
tt: Option<rustc_ast::expand::typetree::FncTree>,
);
fn memmove(
&mut self,
@@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
temp.val.store_with_flags(self, dst.with_type(layout), flags);
} else if !layout.is_zst() {
let bytes = self.const_usize(layout.size.bytes());
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
}
}

View File

@@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() {
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(assume_incomplete_release, true);
tracked!(autodiff, vec![AutoDiff::Enable]);
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
tracked!(binary_dep_depinfo, true);
tracked!(box_noalias, false);

View File

@@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let child = typetree_from_ty(tcx, inner_ty);
return TypeTree(vec![Type {
offset: -1,
size: 8, // TODO(KMJ-007): Get actual pointer size from target
size: tcx.data_layout.pointer_size().bytes_usize(),
kind: Kind::Pointer,
child,
}]);
}
// TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
TypeTree::new()
}

View File

@@ -30,4 +30,4 @@ fn main() {
let output_ = d_simple(&x, &mut df_dx, 1.0);
assert_eq!(output, output_);
assert_eq!(2.0, df_dx);
}
}

View File

@@ -0,0 +1,8 @@
; Check that enzyme_type attributes are present in the LLVM IR function definition
; This verifies our TypeTree system correctly attaches metadata for Enzyme
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
; Check that llvm.memcpy exists (either call or declare)
CHECK: {{(call|declare).*}}@llvm.memcpy

View File

@@ -0,0 +1,13 @@
CHECK: force_memcpy
CHECK: @llvm.memcpy.p0.p0.i64
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}

View File

@@ -0,0 +1,36 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
use std::ptr;
#[inline(never)]
fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
unsafe {
ptr::copy_nonoverlapping(src, dst, count);
}
}
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
#[no_mangle]
fn test_memcpy(input: &[f64; 128]) -> f64 {
let mut local_data = [0.0f64; 128];
// Use a separate function to prevent inlining and optimization
force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
// Sum only first few elements to keep the computation simple
local_data[0] * local_data[0]
+ local_data[1] * local_data[1]
+ local_data[2] * local_data[2]
+ local_data[3] * local_data[3]
}
fn main() {
let input = [1.0; 128];
let mut d_input = [0.0; 128];
let result = test_memcpy(&input);
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
assert_eq!(result, result_d);
println!("Memcpy test passed: result = {}", result);
}

View File

@@ -0,0 +1,39 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// First, compile to LLVM IR to check for enzyme_type attributes
let _ir_output = rustc()
.input("memcpy.rs")
.arg("-Zautodiff=Enable")
.arg("-Zautodiff=NoPostopt")
.opt_level("0")
.arg("--emit=llvm-ir")
.arg("-o")
.arg("main.ll")
.run();
// Then compile with TypeTree analysis output for the existing checks
let output = rustc()
.input("memcpy.rs")
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
.run();
let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();
let ir_content = rfs::read_to_string("main.ll");
rfs::write("memcpy.stdout", &stdout);
rfs::write("memcpy.stderr", &stderr);
rfs::write("main.ir", &ir_content);
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
}

View File

@@ -23,14 +23,8 @@ fn main() {
.run();
// Verify NoTT version does NOT have enzyme_type attributes
llvm_filecheck()
.patterns("nott.check")
.stdin_buf(rfs::read("nott.ll"))
.run();
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
// Verify TypeTree version DOES have enzyme_type attributes
llvm_filecheck()
.patterns("with_tt.check")
.stdin_buf(rfs::read("with_tt.ll"))
.run();
}
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
}

View File

@@ -12,4 +12,4 @@ fn main() {
let x = 2.0;
let mut dx = 0.0;
let _result = d_square(&x, &mut dx, 1.0);
}
}

View File

@@ -16,4 +16,4 @@ fn main() {
let x = 2.0;
let mut dx = 0.0;
let result = d_square(&x, &mut dx, 1.0);
}
}