Added TypeTree support for memcpy
This commit is contained in:
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
||||
scratch_align,
|
||||
bx.const_usize(copy_bytes),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
bx.lifetime_end(llscratch, scratch_size);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
|
||||
use std::ops::Deref;
|
||||
use std::{iter, ptr};
|
||||
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
pub(crate) mod autodiff;
|
||||
pub(crate) mod gpu_offload;
|
||||
|
||||
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
||||
src_align: Align,
|
||||
size: &'ll Value,
|
||||
flags: MemFlags,
|
||||
tt: Option<FncTree>,
|
||||
) {
|
||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||
let size = self.intcast(size, self.type_isize(), false);
|
||||
let is_volatile = flags.contains(MemFlags::VOLATILE);
|
||||
unsafe {
|
||||
let memcpy = unsafe {
|
||||
llvm::LLVMRustBuildMemCpy(
|
||||
self.llbuilder,
|
||||
dst,
|
||||
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
||||
src_align.bytes() as c_uint,
|
||||
size,
|
||||
is_volatile,
|
||||
);
|
||||
)
|
||||
};
|
||||
|
||||
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
|
||||
// a memcpy during autodiff, it needs to know the structure of the data being
|
||||
// copied to properly track derivatives. For example, copying an array of floats
|
||||
// vs. copying a struct with mixed types requires different derivative handling.
|
||||
// The TypeTree tells Enzyme exactly what memory layout to expect.
|
||||
if let Some(tt) = tt {
|
||||
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
|
||||
DT_Half = 3,
|
||||
DT_Float = 4,
|
||||
DT_Double = 5,
|
||||
// FIXME(KMJ-007): handle f128 using long double here (https://github.com/EnzymeAD/Enzyme/issues/1600)
|
||||
DT_Unknown = 6,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use std::ffi::{CString, c_char, c_uint};
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
#[cfg(llvm_enzyme)]
|
||||
use {
|
||||
crate::attributes,
|
||||
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
|
||||
std::ffi::{CString, c_char, c_uint},
|
||||
};
|
||||
|
||||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
|
||||
|
||||
use crate::attributes;
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
||||
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
|
||||
enzyme_tt
|
||||
}
|
||||
|
||||
/// Stand-in for `to_enzyme_typetree` that is compiled only when the
/// `llvm_enzyme` cfg is NOT set.
///
/// The function diverges (return type `!`): its body unconditionally hits
/// `unimplemented!`, so any call panics with the message below. All three
/// parameters are accepted but unused, hence the leading underscores.
#[cfg(not(llvm_enzyme))]
|
||||
fn to_enzyme_typetree(
|
||||
_rust_typetree: RustTypeTree,
|
||||
_data_layout: &str,
|
||||
_llcx: &llvm::Context,
|
||||
) -> ! {
|
||||
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
|
||||
}
|
||||
|
||||
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
|
||||
@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
|
||||
src_align,
|
||||
bx.const_u32(layout.layout.size().bytes() as u32),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
tmp
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user