autodiff: fixed test to be more precise for type tree checking
This commit is contained in:
@@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD {
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeInsertEq(
|
||||
CTT: CTypeTreeRef,
|
||||
indices: *const i64,
|
||||
len: usize,
|
||||
ct: CConcreteType,
|
||||
ctx: &Context,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
}
|
||||
@@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
|
||||
CTT: CTypeTreeRef,
|
||||
indices: *const i64,
|
||||
len: usize,
|
||||
ct: CConcreteType,
|
||||
ctx: &Context,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
|
||||
unimplemented!()
|
||||
}
|
||||
@@ -312,6 +329,12 @@ impl TypeTree {
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
|
||||
unsafe {
|
||||
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TypeTree {
|
||||
|
||||
@@ -8,22 +8,24 @@ use {
|
||||
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
||||
///
|
||||
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
|
||||
/// and converts it to Enzyme's internal C++ TypeTree representation that
|
||||
/// Enzyme can understand during differentiation analysis.
|
||||
#[cfg(llvm_enzyme)]
|
||||
fn to_enzyme_typetree(
|
||||
rust_typetree: RustTypeTree,
|
||||
data_layout: &str,
|
||||
_data_layout: &str,
|
||||
llcx: &llvm::Context,
|
||||
) -> llvm::TypeTree {
|
||||
// Start with an empty TypeTree
|
||||
let mut enzyme_tt = llvm::TypeTree::new();
|
||||
|
||||
// Convert each Type in the Rust TypeTree to Enzyme format
|
||||
for rust_type in rust_typetree.0 {
|
||||
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
|
||||
enzyme_tt
|
||||
}
|
||||
#[cfg(llvm_enzyme)]
|
||||
fn process_typetree_recursive(
|
||||
enzyme_tt: &mut llvm::TypeTree,
|
||||
rust_typetree: &RustTypeTree,
|
||||
parent_indices: &[i64],
|
||||
llcx: &llvm::Context,
|
||||
) {
|
||||
for rust_type in &rust_typetree.0 {
|
||||
let concrete_type = match rust_type.kind {
|
||||
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
|
||||
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
|
||||
@@ -35,25 +37,27 @@ fn to_enzyme_typetree(
|
||||
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
|
||||
};
|
||||
|
||||
// Create a TypeTree for this specific type
|
||||
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
|
||||
|
||||
// Apply offset if specified
|
||||
let type_tt = if rust_type.offset == -1 {
|
||||
type_tt // -1 means everywhere/no specific offset
|
||||
let mut indices = parent_indices.to_vec();
|
||||
if !parent_indices.is_empty() {
|
||||
if rust_type.offset == -1 {
|
||||
indices.push(-1);
|
||||
} else {
|
||||
indices.push(rust_type.offset as i64);
|
||||
}
|
||||
} else if rust_type.offset == -1 {
|
||||
indices.push(-1);
|
||||
} else {
|
||||
// Apply specific offset positioning
|
||||
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
|
||||
};
|
||||
indices.push(rust_type.offset as i64);
|
||||
}
|
||||
|
||||
// Merge this type into the main TypeTree
|
||||
enzyme_tt = enzyme_tt.merge(type_tt);
|
||||
enzyme_tt.insert(&indices, concrete_type, llcx);
|
||||
|
||||
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
|
||||
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
|
||||
}
|
||||
}
|
||||
|
||||
enzyme_tt
|
||||
}
|
||||
|
||||
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
llmod: &'ll llvm::Module,
|
||||
@@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>(
|
||||
let inputs = tt.args;
|
||||
let ret_tt: RustTypeTree = tt.ret;
|
||||
|
||||
// Get LLVM data layout string for TypeTree conversion
|
||||
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
|
||||
let llvm_data_layout =
|
||||
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
|
||||
.expect("got a non-UTF8 data-layout from LLVM");
|
||||
|
||||
// Attribute name that Enzyme recognizes for TypeTree information
|
||||
let attr_name = "enzyme_type";
|
||||
let c_attr_name = CString::new(attr_name).unwrap();
|
||||
|
||||
// Attach TypeTree attributes to each input parameter
|
||||
// Enzyme uses these to understand parameter memory layouts during differentiation
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
unsafe {
|
||||
// Convert Rust TypeTree to Enzyme's internal format
|
||||
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
|
||||
|
||||
// Serialize TypeTree to string format that Enzyme can parse
|
||||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
|
||||
let c_str = std::ffi::CStr::from_ptr(c_str);
|
||||
|
||||
// Create LLVM string attribute with TypeTree information
|
||||
let attr = llvm::LLVMCreateStringAttribute(
|
||||
llcx,
|
||||
c_attr_name.as_ptr(),
|
||||
@@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>(
|
||||
c_str.to_bytes().len() as c_uint,
|
||||
);
|
||||
|
||||
// Attach attribute to the specific function parameter
|
||||
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
|
||||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
|
||||
|
||||
// Free the C string to prevent memory leaks
|
||||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Attach TypeTree attribute to the return type
|
||||
// Enzyme needs this to understand how to handle return value derivatives
|
||||
unsafe {
|
||||
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
|
||||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
|
||||
@@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>(
|
||||
c_str.to_bytes().len() as c_uint,
|
||||
);
|
||||
|
||||
// Attach to function return type
|
||||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
|
||||
|
||||
// Free the C string
|
||||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback implementation when Enzyme is not available
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
_llmod: &'ll llvm::Module,
|
||||
|
||||
Reference in New Issue
Block a user