added typetree support for memcpy

This commit is contained in:
Karan Janthe
2025-08-23 23:10:48 +00:00
parent 5d3ebc3804
commit 664e83b3e7
21 changed files with 135 additions and 34 deletions

View File

@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}

View File

@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};
use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

View File

@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
// FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
DT_Unknown = 6,
}

View File

@@ -1,8 +1,11 @@
use std::ffi::{CString, c_char, c_uint};
use rustc_ast::expand::typetree::FncTree;
#[cfg(llvm_enzyme)]
use {
crate::attributes,
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
std::ffi::{CString, c_char, c_uint},
};
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
use crate::attributes;
use crate::llvm::{self, Value};
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
enzyme_tt
}
#[cfg(not(llvm_enzyme))]
fn to_enzyme_typetree(
_rust_typetree: RustTypeTree,
_data_layout: &str,
_llcx: &llvm::Context,
) -> ! {
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
}
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
#[cfg(llvm_enzyme)]
pub(crate) fn add_tt<'ll>(

View File

@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
src_align,
bx.const_u32(layout.layout.size().bytes() as u32),
MemFlags::empty(),
None,
);
tmp
} else {