added typetree support for memcpy
This commit is contained in:
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
|
||||
_src_align: Align,
|
||||
size: RValue<'gcc>,
|
||||
flags: MemFlags,
|
||||
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
|
||||
) {
|
||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||
let size = self.intcast(size, self.type_size_t(), false);
|
||||
|
||||
@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
||||
scratch_align,
|
||||
bx.const_usize(self.layout.size.bytes()),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
|
||||
bx.lifetime_end(scratch, scratch_size);
|
||||
|
||||
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
||||
scratch_align,
|
||||
bx.const_usize(copy_bytes),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
bx.lifetime_end(llscratch, scratch_size);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
|
||||
use std::ops::Deref;
|
||||
use std::{iter, ptr};
|
||||
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
pub(crate) mod autodiff;
|
||||
pub(crate) mod gpu_offload;
|
||||
|
||||
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
||||
src_align: Align,
|
||||
size: &'ll Value,
|
||||
flags: MemFlags,
|
||||
tt: Option<FncTree>,
|
||||
) {
|
||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||
let size = self.intcast(size, self.type_isize(), false);
|
||||
let is_volatile = flags.contains(MemFlags::VOLATILE);
|
||||
unsafe {
|
||||
let memcpy = unsafe {
|
||||
llvm::LLVMRustBuildMemCpy(
|
||||
self.llbuilder,
|
||||
dst,
|
||||
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
||||
src_align.bytes() as c_uint,
|
||||
size,
|
||||
is_volatile,
|
||||
);
|
||||
)
|
||||
};
|
||||
|
||||
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
|
||||
// a memcpy during autodiff, it needs to know the structure of the data being
|
||||
// copied to properly track derivatives. For example, copying an array of floats
|
||||
// vs. copying a struct with mixed types requires different derivative handling.
|
||||
// The TypeTree tells Enzyme exactly what memory layout to expect.
|
||||
if let Some(tt) = tt {
|
||||
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
|
||||
DT_Half = 3,
|
||||
DT_Float = 4,
|
||||
DT_Double = 5,
|
||||
// FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
|
||||
DT_Unknown = 6,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use std::ffi::{CString, c_char, c_uint};
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
#[cfg(llvm_enzyme)]
|
||||
use {
|
||||
crate::attributes,
|
||||
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
|
||||
std::ffi::{CString, c_char, c_uint},
|
||||
};
|
||||
|
||||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
|
||||
|
||||
use crate::attributes;
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
||||
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
|
||||
enzyme_tt
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
fn to_enzyme_typetree(
|
||||
_rust_typetree: RustTypeTree,
|
||||
_data_layout: &str,
|
||||
_llcx: &llvm::Context,
|
||||
) -> ! {
|
||||
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
|
||||
}
|
||||
|
||||
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
|
||||
@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
|
||||
src_align,
|
||||
bx.const_u32(layout.layout.size().bytes() as u32),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
tmp
|
||||
} else {
|
||||
|
||||
@@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
||||
align,
|
||||
bx.const_usize(copy_bytes),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
// ...and then load it with the ABI type.
|
||||
llval = load_cast(bx, cast, llscratch, scratch_align);
|
||||
|
||||
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
|
||||
if allow_overlap {
|
||||
bx.memmove(dst, align, src, align, size, flags);
|
||||
} else {
|
||||
bx.memcpy(dst, align, src, align, size, flags);
|
||||
bx.memcpy(dst, align, src, align, size, flags, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
||||
let align = pointee_layout.align;
|
||||
let dst = dst_val.immediate();
|
||||
let src = src_val.immediate();
|
||||
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
|
||||
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
|
||||
}
|
||||
mir::StatementKind::FakeRead(..)
|
||||
| mir::StatementKind::Retag { .. }
|
||||
|
||||
@@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
||||
src_align: Align,
|
||||
size: Self::Value,
|
||||
flags: MemFlags,
|
||||
tt: Option<rustc_ast::expand::typetree::FncTree>,
|
||||
);
|
||||
fn memmove(
|
||||
&mut self,
|
||||
@@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
||||
temp.val.store_with_flags(self, dst.with_type(layout), flags);
|
||||
} else if !layout.is_zst() {
|
||||
let bytes = self.const_usize(layout.size.bytes());
|
||||
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
|
||||
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() {
|
||||
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
||||
tracked!(always_encode_mir, true);
|
||||
tracked!(assume_incomplete_release, true);
|
||||
tracked!(autodiff, vec![AutoDiff::Enable]);
|
||||
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
|
||||
tracked!(binary_dep_depinfo, true);
|
||||
tracked!(box_noalias, false);
|
||||
|
||||
@@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
let child = typetree_from_ty(tcx, inner_ty);
|
||||
return TypeTree(vec![Type {
|
||||
offset: -1,
|
||||
size: 8, // TODO(KMJ-007): Get actual pointer size from target
|
||||
size: tcx.data_layout.pointer_size().bytes_usize(),
|
||||
kind: Kind::Pointer,
|
||||
child,
|
||||
}]);
|
||||
}
|
||||
|
||||
// TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
|
||||
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
|
||||
TypeTree::new()
|
||||
}
|
||||
|
||||
@@ -30,4 +30,4 @@ fn main() {
|
||||
let output_ = d_simple(&x, &mut df_dx, 1.0);
|
||||
assert_eq!(output, output_);
|
||||
assert_eq!(2.0, df_dx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
; Check that enzyme_type attributes are present in the LLVM IR function definition
|
||||
; This verifies our TypeTree system correctly attaches metadata for Enzyme
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
|
||||
|
||||
; Check that llvm.memcpy exists (either call or declare)
|
||||
CHECK: {{(call|declare).*}}@llvm.memcpy
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
CHECK: force_memcpy
|
||||
|
||||
CHECK: @llvm.memcpy.p0.p0.i64
|
||||
|
||||
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
|
||||
|
||||
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
|
||||
|
||||
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
|
||||
|
||||
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
|
||||
|
||||
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
|
||||
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
use std::ptr;
|
||||
|
||||
#[inline(never)]
|
||||
fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(src, dst, count);
|
||||
}
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_memcpy(input: &[f64; 128]) -> f64 {
|
||||
let mut local_data = [0.0f64; 128];
|
||||
|
||||
// Use a separate function to prevent inlining and optimization
|
||||
force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
|
||||
|
||||
// Sum only first few elements to keep the computation simple
|
||||
local_data[0] * local_data[0]
|
||||
+ local_data[1] * local_data[1]
|
||||
+ local_data[2] * local_data[2]
|
||||
+ local_data[3] * local_data[3]
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let input = [1.0; 128];
|
||||
let mut d_input = [0.0; 128];
|
||||
let result = test_memcpy(&input);
|
||||
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
|
||||
|
||||
assert_eq!(result, result_d);
|
||||
println!("Memcpy test passed: result = {}", result);
|
||||
}
|
||||
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// First, compile to LLVM IR to check for enzyme_type attributes
|
||||
let _ir_output = rustc()
|
||||
.input("memcpy.rs")
|
||||
.arg("-Zautodiff=Enable")
|
||||
.arg("-Zautodiff=NoPostopt")
|
||||
.opt_level("0")
|
||||
.arg("--emit=llvm-ir")
|
||||
.arg("-o")
|
||||
.arg("main.ll")
|
||||
.run();
|
||||
|
||||
// Then compile with TypeTree analysis output for the existing checks
|
||||
let output = rustc()
|
||||
.input("memcpy.rs")
|
||||
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
|
||||
.arg("-Zautodiff=NoPostopt")
|
||||
.opt_level("3")
|
||||
.arg("-Clto=fat")
|
||||
.arg("-g")
|
||||
.run();
|
||||
|
||||
let stdout = output.stdout_utf8();
|
||||
let stderr = output.stderr_utf8();
|
||||
let ir_content = rfs::read_to_string("main.ll");
|
||||
|
||||
rfs::write("memcpy.stdout", &stdout);
|
||||
rfs::write("memcpy.stderr", &stderr);
|
||||
rfs::write("main.ir", &ir_content);
|
||||
|
||||
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
|
||||
|
||||
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
|
||||
}
|
||||
@@ -23,14 +23,8 @@ fn main() {
|
||||
.run();
|
||||
|
||||
// Verify NoTT version does NOT have enzyme_type attributes
|
||||
llvm_filecheck()
|
||||
.patterns("nott.check")
|
||||
.stdin_buf(rfs::read("nott.ll"))
|
||||
.run();
|
||||
|
||||
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
|
||||
|
||||
// Verify TypeTree version DOES have enzyme_type attributes
|
||||
llvm_filecheck()
|
||||
.patterns("with_tt.check")
|
||||
.stdin_buf(rfs::read("with_tt.ll"))
|
||||
.run();
|
||||
}
|
||||
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
|
||||
}
|
||||
|
||||
@@ -12,4 +12,4 @@ fn main() {
|
||||
let x = 2.0;
|
||||
let mut dx = 0.0;
|
||||
let _result = d_square(&x, &mut dx, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,4 +16,4 @@ fn main() {
|
||||
let x = 2.0;
|
||||
let mut dx = 0.0;
|
||||
let result = d_square(&x, &mut dx, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user