autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
This commit is contained in:
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
#[cfg(llvm_enzyme)]
|
||||
#[cfg(feature = "llvm_enzyme")]
|
||||
use {
|
||||
crate::attributes,
|
||||
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
|
||||
@@ -8,7 +8,7 @@ use {
|
||||
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
#[cfg(llvm_enzyme)]
|
||||
#[cfg(feature = "llvm_enzyme")]
|
||||
fn to_enzyme_typetree(
|
||||
rust_typetree: RustTypeTree,
|
||||
_data_layout: &str,
|
||||
@@ -18,7 +18,7 @@ fn to_enzyme_typetree(
|
||||
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
|
||||
enzyme_tt
|
||||
}
|
||||
#[cfg(llvm_enzyme)]
|
||||
#[cfg(feature = "llvm_enzyme")]
|
||||
fn process_typetree_recursive(
|
||||
enzyme_tt: &mut llvm::TypeTree,
|
||||
rust_typetree: &RustTypeTree,
|
||||
@@ -56,7 +56,7 @@ fn process_typetree_recursive(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(llvm_enzyme)]
|
||||
#[cfg(feature = "llvm_enzyme")]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
llmod: &'ll llvm::Module,
|
||||
llcx: &'ll llvm::Context,
|
||||
@@ -111,7 +111,7 @@ pub(crate) fn add_tt<'ll>(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
#[cfg(not(feature = "llvm_enzyme"))]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
_llmod: &'ll llvm::Module,
|
||||
_llcx: &'ll llvm::Context,
|
||||
|
||||
@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
|
||||
MD.NoHWAddress = true;
|
||||
GV.setSanitizerMetadata(MD);
|
||||
}
|
||||
|
||||
#ifdef ENZYME
|
||||
extern "C" {
|
||||
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
|
||||
}
|
||||
|
||||
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
|
||||
#else
|
||||
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
|
||||
return 6; // Default fallback depth
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
|
||||
pub use rustc_type_ir::*;
|
||||
#[allow(hidden_glob_reexports, unused_imports)]
|
||||
use rustc_type_ir::{InferCtxtLike, Interner};
|
||||
use tracing::{debug, instrument};
|
||||
use tracing::{debug, instrument, trace};
|
||||
pub use vtable::*;
|
||||
use {rustc_ast as ast, rustc_hir as hir};
|
||||
|
||||
@@ -2256,6 +2256,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
|
||||
}
|
||||
|
||||
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
|
||||
/// from pathological deeply nested types. Combined with cycle detection.
|
||||
const MAX_TYPETREE_DEPTH: usize = 6;
|
||||
|
||||
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
|
||||
fn typetree_from_ty_inner<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
@@ -2263,19 +2267,8 @@ fn typetree_from_ty_inner<'tcx>(
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
) -> TypeTree {
|
||||
#[cfg(llvm_enzyme)]
|
||||
{
|
||||
unsafe extern "C" {
|
||||
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
|
||||
}
|
||||
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
|
||||
if depth > max_depth {
|
||||
return TypeTree::new();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
if depth > 6 {
|
||||
if depth >= MAX_TYPETREE_DEPTH {
|
||||
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
|
||||
Submodule src/llvm-project updated: 2a22c80143...333793696b
Submodule src/tools/cargo updated: a4bd03c92d...966f94733b
Reference in New Issue
Block a user