autodiff: typetree recursive depth query from enzyme with fallback

Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
This commit is contained in:
Karan Janthe
2025-09-12 06:11:18 +00:00
parent 4520926bb5
commit 3ba5f19182
6 changed files with 26 additions and 22 deletions

View File

@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
}
unsafe extern "C" {

View File

@@ -1,5 +1,5 @@
use rustc_ast::expand::typetree::FncTree;
#[cfg(llvm_enzyme)]
#[cfg(feature = "llvm_enzyme")]
use {
crate::attributes,
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
@@ -8,7 +8,7 @@ use {
use crate::llvm::{self, Value};
#[cfg(llvm_enzyme)]
#[cfg(feature = "llvm_enzyme")]
fn to_enzyme_typetree(
rust_typetree: RustTypeTree,
_data_layout: &str,
@@ -18,7 +18,7 @@ fn to_enzyme_typetree(
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
enzyme_tt
}
#[cfg(llvm_enzyme)]
#[cfg(feature = "llvm_enzyme")]
fn process_typetree_recursive(
enzyme_tt: &mut llvm::TypeTree,
rust_typetree: &RustTypeTree,
@@ -56,7 +56,7 @@ fn process_typetree_recursive(
}
}
#[cfg(llvm_enzyme)]
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn add_tt<'ll>(
llmod: &'ll llvm::Module,
llcx: &'ll llvm::Context,
@@ -111,7 +111,7 @@ pub(crate) fn add_tt<'ll>(
}
}
#[cfg(not(llvm_enzyme))]
#[cfg(not(feature = "llvm_enzyme"))]
pub(crate) fn add_tt<'ll>(
_llmod: &'ll llvm::Module,
_llcx: &'ll llvm::Context,

View File

@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
MD.NoHWAddress = true;
GV.setSanitizerMetadata(MD);
}
#ifdef ENZYME
extern "C" {
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
}
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
#else
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
return 6; // Default fallback depth
}
#endif

View File

@@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
pub use rustc_type_ir::*;
#[allow(hidden_glob_reexports, unused_imports)]
use rustc_type_ir::{InferCtxtLike, Interner};
use tracing::{debug, instrument};
use tracing::{debug, instrument, trace};
pub use vtable::*;
use {rustc_ast as ast, rustc_hir as hir};
@@ -2256,6 +2256,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
}
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
/// from pathological deeply nested types. Combined with cycle detection.
const MAX_TYPETREE_DEPTH: usize = 6;
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
fn typetree_from_ty_inner<'tcx>(
tcx: TyCtxt<'tcx>,
@@ -2263,19 +2267,8 @@ fn typetree_from_ty_inner<'tcx>(
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
) -> TypeTree {
#[cfg(llvm_enzyme)]
{
unsafe extern "C" {
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
}
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
if depth > max_depth {
return TypeTree::new();
}
}
#[cfg(not(llvm_enzyme))]
if depth > 6 {
if depth >= MAX_TYPETREE_DEPTH {
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
return TypeTree::new();
}