added typetree support for memcpy
This commit is contained in:
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
|
|||||||
_src_align: Align,
|
_src_align: Align,
|
||||||
size: RValue<'gcc>,
|
size: RValue<'gcc>,
|
||||||
flags: MemFlags,
|
flags: MemFlags,
|
||||||
|
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
|
||||||
) {
|
) {
|
||||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||||
let size = self.intcast(size, self.type_size_t(), false);
|
let size = self.intcast(size, self.type_size_t(), false);
|
||||||
|
|||||||
@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
|||||||
scratch_align,
|
scratch_align,
|
||||||
bx.const_usize(self.layout.size.bytes()),
|
bx.const_usize(self.layout.size.bytes()),
|
||||||
MemFlags::empty(),
|
MemFlags::empty(),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
bx.lifetime_end(scratch, scratch_size);
|
bx.lifetime_end(scratch, scratch_size);
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
|||||||
scratch_align,
|
scratch_align,
|
||||||
bx.const_usize(copy_bytes),
|
bx.const_usize(copy_bytes),
|
||||||
MemFlags::empty(),
|
MemFlags::empty(),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
bx.lifetime_end(llscratch, scratch_size);
|
bx.lifetime_end(llscratch, scratch_size);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
|
|||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::{iter, ptr};
|
use std::{iter, ptr};
|
||||||
|
|
||||||
|
use rustc_ast::expand::typetree::FncTree;
|
||||||
pub(crate) mod autodiff;
|
pub(crate) mod autodiff;
|
||||||
pub(crate) mod gpu_offload;
|
pub(crate) mod gpu_offload;
|
||||||
|
|
||||||
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
|||||||
src_align: Align,
|
src_align: Align,
|
||||||
size: &'ll Value,
|
size: &'ll Value,
|
||||||
flags: MemFlags,
|
flags: MemFlags,
|
||||||
|
tt: Option<FncTree>,
|
||||||
) {
|
) {
|
||||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||||
let size = self.intcast(size, self.type_isize(), false);
|
let size = self.intcast(size, self.type_isize(), false);
|
||||||
let is_volatile = flags.contains(MemFlags::VOLATILE);
|
let is_volatile = flags.contains(MemFlags::VOLATILE);
|
||||||
unsafe {
|
let memcpy = unsafe {
|
||||||
llvm::LLVMRustBuildMemCpy(
|
llvm::LLVMRustBuildMemCpy(
|
||||||
self.llbuilder,
|
self.llbuilder,
|
||||||
dst,
|
dst,
|
||||||
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
|||||||
src_align.bytes() as c_uint,
|
src_align.bytes() as c_uint,
|
||||||
size,
|
size,
|
||||||
is_volatile,
|
is_volatile,
|
||||||
);
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
|
||||||
|
// a memcpy during autodiff, it needs to know the structure of the data being
|
||||||
|
// copied to properly track derivatives. For example, copying an array of floats
|
||||||
|
// vs. copying a struct with mixed types requires different derivative handling.
|
||||||
|
// The TypeTree tells Enzyme exactly what memory layout to expect.
|
||||||
|
if let Some(tt) = tt {
|
||||||
|
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
|
|||||||
DT_Half = 3,
|
DT_Half = 3,
|
||||||
DT_Float = 4,
|
DT_Float = 4,
|
||||||
DT_Double = 5,
|
DT_Double = 5,
|
||||||
|
// FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
|
||||||
DT_Unknown = 6,
|
DT_Unknown = 6,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
use std::ffi::{CString, c_char, c_uint};
|
use rustc_ast::expand::typetree::FncTree;
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
use {
|
||||||
|
crate::attributes,
|
||||||
|
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
|
||||||
|
std::ffi::{CString, c_char, c_uint},
|
||||||
|
};
|
||||||
|
|
||||||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
|
|
||||||
|
|
||||||
use crate::attributes;
|
|
||||||
use crate::llvm::{self, Value};
|
use crate::llvm::{self, Value};
|
||||||
|
|
||||||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
||||||
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
|
|||||||
enzyme_tt
|
enzyme_tt
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(llvm_enzyme))]
|
|
||||||
fn to_enzyme_typetree(
|
|
||||||
_rust_typetree: RustTypeTree,
|
|
||||||
_data_layout: &str,
|
|
||||||
_llcx: &llvm::Context,
|
|
||||||
) -> ! {
|
|
||||||
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
||||||
#[cfg(llvm_enzyme)]
|
#[cfg(llvm_enzyme)]
|
||||||
pub(crate) fn add_tt<'ll>(
|
pub(crate) fn add_tt<'ll>(
|
||||||
|
|||||||
@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
|
|||||||
src_align,
|
src_align,
|
||||||
bx.const_u32(layout.layout.size().bytes() as u32),
|
bx.const_u32(layout.layout.size().bytes() as u32),
|
||||||
MemFlags::empty(),
|
MemFlags::empty(),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
tmp
|
tmp
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
|||||||
align,
|
align,
|
||||||
bx.const_usize(copy_bytes),
|
bx.const_usize(copy_bytes),
|
||||||
MemFlags::empty(),
|
MemFlags::empty(),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
// ...and then load it with the ABI type.
|
// ...and then load it with the ABI type.
|
||||||
llval = load_cast(bx, cast, llscratch, scratch_align);
|
llval = load_cast(bx, cast, llscratch, scratch_align);
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
|
|||||||
if allow_overlap {
|
if allow_overlap {
|
||||||
bx.memmove(dst, align, src, align, size, flags);
|
bx.memmove(dst, align, src, align, size, flags);
|
||||||
} else {
|
} else {
|
||||||
bx.memcpy(dst, align, src, align, size, flags);
|
bx.memcpy(dst, align, src, align, size, flags, None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
|||||||
let align = pointee_layout.align;
|
let align = pointee_layout.align;
|
||||||
let dst = dst_val.immediate();
|
let dst = dst_val.immediate();
|
||||||
let src = src_val.immediate();
|
let src = src_val.immediate();
|
||||||
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
|
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
|
||||||
}
|
}
|
||||||
mir::StatementKind::FakeRead(..)
|
mir::StatementKind::FakeRead(..)
|
||||||
| mir::StatementKind::Retag { .. }
|
| mir::StatementKind::Retag { .. }
|
||||||
|
|||||||
@@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
|||||||
src_align: Align,
|
src_align: Align,
|
||||||
size: Self::Value,
|
size: Self::Value,
|
||||||
flags: MemFlags,
|
flags: MemFlags,
|
||||||
|
tt: Option<rustc_ast::expand::typetree::FncTree>,
|
||||||
);
|
);
|
||||||
fn memmove(
|
fn memmove(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
|||||||
temp.val.store_with_flags(self, dst.with_type(layout), flags);
|
temp.val.store_with_flags(self, dst.with_type(layout), flags);
|
||||||
} else if !layout.is_zst() {
|
} else if !layout.is_zst() {
|
||||||
let bytes = self.const_usize(layout.size.bytes());
|
let bytes = self.const_usize(layout.size.bytes());
|
||||||
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
|
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() {
|
|||||||
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
||||||
tracked!(always_encode_mir, true);
|
tracked!(always_encode_mir, true);
|
||||||
tracked!(assume_incomplete_release, true);
|
tracked!(assume_incomplete_release, true);
|
||||||
tracked!(autodiff, vec![AutoDiff::Enable]);
|
|
||||||
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
|
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
|
||||||
tracked!(binary_dep_depinfo, true);
|
tracked!(binary_dep_depinfo, true);
|
||||||
tracked!(box_noalias, false);
|
tracked!(box_noalias, false);
|
||||||
|
|||||||
@@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||||||
let child = typetree_from_ty(tcx, inner_ty);
|
let child = typetree_from_ty(tcx, inner_ty);
|
||||||
return TypeTree(vec![Type {
|
return TypeTree(vec![Type {
|
||||||
offset: -1,
|
offset: -1,
|
||||||
size: 8, // TODO(KMJ-007): Get actual pointer size from target
|
size: tcx.data_layout.pointer_size().bytes_usize(),
|
||||||
kind: Kind::Pointer,
|
kind: Kind::Pointer,
|
||||||
child,
|
child,
|
||||||
}]);
|
}]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
|
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
|
||||||
TypeTree::new()
|
TypeTree::new()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,4 +30,4 @@ fn main() {
|
|||||||
let output_ = d_simple(&x, &mut df_dx, 1.0);
|
let output_ = d_simple(&x, &mut df_dx, 1.0);
|
||||||
assert_eq!(output, output_);
|
assert_eq!(output, output_);
|
||||||
assert_eq!(2.0, df_dx);
|
assert_eq!(2.0, df_dx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
; Check that enzyme_type attributes are present in the LLVM IR function definition
|
||||||
|
; This verifies our TypeTree system correctly attaches metadata for Enzyme
|
||||||
|
|
||||||
|
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
|
||||||
|
|
||||||
|
; Check that llvm.memcpy exists (either call or declare)
|
||||||
|
CHECK: {{(call|declare).*}}@llvm.memcpy
|
||||||
|
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
CHECK: force_memcpy
|
||||||
|
|
||||||
|
CHECK: @llvm.memcpy.p0.p0.i64
|
||||||
|
|
||||||
|
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
|
||||||
|
|
||||||
|
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
|
||||||
|
|
||||||
|
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
|
||||||
|
|
||||||
|
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
|
||||||
|
|
||||||
|
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
|
||||||
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#![feature(autodiff)]
|
||||||
|
|
||||||
|
use std::autodiff::autodiff_reverse;
|
||||||
|
use std::ptr;
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
|
fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
|
||||||
|
unsafe {
|
||||||
|
ptr::copy_nonoverlapping(src, dst, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
|
||||||
|
#[no_mangle]
|
||||||
|
fn test_memcpy(input: &[f64; 128]) -> f64 {
|
||||||
|
let mut local_data = [0.0f64; 128];
|
||||||
|
|
||||||
|
// Use a separate function to prevent inlining and optimization
|
||||||
|
force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
|
||||||
|
|
||||||
|
// Sum only first few elements to keep the computation simple
|
||||||
|
local_data[0] * local_data[0]
|
||||||
|
+ local_data[1] * local_data[1]
|
||||||
|
+ local_data[2] * local_data[2]
|
||||||
|
+ local_data[3] * local_data[3]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let input = [1.0; 128];
|
||||||
|
let mut d_input = [0.0; 128];
|
||||||
|
let result = test_memcpy(&input);
|
||||||
|
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
|
||||||
|
|
||||||
|
assert_eq!(result, result_d);
|
||||||
|
println!("Memcpy test passed: result = {}", result);
|
||||||
|
}
|
||||||
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//@ needs-enzyme
|
||||||
|
//@ ignore-cross-compile
|
||||||
|
|
||||||
|
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// First, compile to LLVM IR to check for enzyme_type attributes
|
||||||
|
let _ir_output = rustc()
|
||||||
|
.input("memcpy.rs")
|
||||||
|
.arg("-Zautodiff=Enable")
|
||||||
|
.arg("-Zautodiff=NoPostopt")
|
||||||
|
.opt_level("0")
|
||||||
|
.arg("--emit=llvm-ir")
|
||||||
|
.arg("-o")
|
||||||
|
.arg("main.ll")
|
||||||
|
.run();
|
||||||
|
|
||||||
|
// Then compile with TypeTree analysis output for the existing checks
|
||||||
|
let output = rustc()
|
||||||
|
.input("memcpy.rs")
|
||||||
|
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
|
||||||
|
.arg("-Zautodiff=NoPostopt")
|
||||||
|
.opt_level("3")
|
||||||
|
.arg("-Clto=fat")
|
||||||
|
.arg("-g")
|
||||||
|
.run();
|
||||||
|
|
||||||
|
let stdout = output.stdout_utf8();
|
||||||
|
let stderr = output.stderr_utf8();
|
||||||
|
let ir_content = rfs::read_to_string("main.ll");
|
||||||
|
|
||||||
|
rfs::write("memcpy.stdout", &stdout);
|
||||||
|
rfs::write("memcpy.stderr", &stderr);
|
||||||
|
rfs::write("main.ir", &ir_content);
|
||||||
|
|
||||||
|
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
|
||||||
|
|
||||||
|
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
|
||||||
|
}
|
||||||
@@ -23,14 +23,8 @@ fn main() {
|
|||||||
.run();
|
.run();
|
||||||
|
|
||||||
// Verify NoTT version does NOT have enzyme_type attributes
|
// Verify NoTT version does NOT have enzyme_type attributes
|
||||||
llvm_filecheck()
|
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
|
||||||
.patterns("nott.check")
|
|
||||||
.stdin_buf(rfs::read("nott.ll"))
|
|
||||||
.run();
|
|
||||||
|
|
||||||
// Verify TypeTree version DOES have enzyme_type attributes
|
// Verify TypeTree version DOES have enzyme_type attributes
|
||||||
llvm_filecheck()
|
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
|
||||||
.patterns("with_tt.check")
|
}
|
||||||
.stdin_buf(rfs::read("with_tt.ll"))
|
|
||||||
.run();
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ fn main() {
|
|||||||
let x = 2.0;
|
let x = 2.0;
|
||||||
let mut dx = 0.0;
|
let mut dx = 0.0;
|
||||||
let _result = d_square(&x, &mut dx, 1.0);
|
let _result = d_square(&x, &mut dx, 1.0);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,4 +16,4 @@ fn main() {
|
|||||||
let x = 2.0;
|
let x = 2.0;
|
||||||
let mut dx = 0.0;
|
let mut dx = 0.0;
|
||||||
let result = d_square(&x, &mut dx, 1.0);
|
let result = d_square(&x, &mut dx, 1.0);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user