Rollup merge of #144197 - KMJ-007:type-tree, r=ZuseZ4

TypeTree support in autodiff

# TypeTrees for Autodiff

## What are TypeTrees?
Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.

## Structure
```rust
TypeTree(Vec<Type>)

Type {
    offset: isize,  // byte offset (-1 = everywhere)
    size: usize,    // size in bytes
    kind: Kind,     // Float, Integer, Pointer, etc.
    child: TypeTree // nested structure
}
```

## Example: `fn compute(x: &f32, data: &[f32]) -> f32`

**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])
```

**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])
```

**Output: `f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 4, kind: Float,
    child: TypeTree::new()
}])
```

## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata

## What Enzyme Does With This Information:

Without TypeTrees (current state):
```llvm
; Enzyme sees generic LLVM IR:
define float ``@distance(ptr*`` %p1, ptr* %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```

With TypeTrees (our implementation):
```llvm
define "enzyme_type"="{[]:Float@float}" float ``@distance(``
    ptr "enzyme_type"="{[]:Pointer}" %p1,
    ptr "enzyme_type"="{[]:Pointer}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```

# TypeTrees - Offset and -1 Explained

## Type Structure

```rust
Type {
    offset: isize, // WHERE this type starts
    size: usize,   // HOW BIG this type is
    kind: Kind,    // WHAT KIND of data (Float, Int, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```

## Offset Values

### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**

```rust
struct Point {
    x: f32, // offset 0, size 4
    y: f32, // offset 4, size 4
    id: i32, // offset 8, size 4
}
```

TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer }  // id at byte 8
])
```

Generates LLVM:
```llvm
"enzyme_type"="{[]:Float@float}"
```

### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**

#### Example 1: Array `[f32; 100]`
```rust
TypeTree(vec![Type {
    offset: -1, // ALL positions
    size: 4,    // each f32 is 4 bytes
    kind: Float, // every element is float
}])
```

Instead of listing 100 separate Types with offsets `0,4,8,12...396`

#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer
    }])
}])
```

#### Example 3: Mixed Structure
```rust
struct Container {
    header: i64,        // offset 0
    data: [f32; 1000],  // offset 8, but elements use -1
}
```

```rust
TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer }, // header
    Type { offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float // ALL array elements
        }])
    }
])
```
This commit is contained in:
Matthias Krüger
2025-09-28 18:13:11 +02:00
committed by GitHub
68 changed files with 1250 additions and 14 deletions

View File

@@ -3,9 +3,36 @@
use libc::{c_char, c_uint};
use super::MetadataKindId;
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
use crate::llvm::{Bool, Builder};
// TypeTree types
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
_unused: [u8; 0],
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
DT_Anything = 0,
DT_Integer = 1,
DT_Pointer = 2,
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
DT_Unknown = 6,
DT_FP128 = 9,
}
pub(crate) struct TypeTree {
pub(crate) inner: CTypeTreeRef,
}
#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
// Enzyme
@@ -68,10 +95,40 @@ pub(crate) mod Enzyme_AD {
use libc::c_void;
use super::{CConcreteType, CTypeTreeRef, Context};
unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
}
// TypeTree functions
unsafe extern "C" {
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
);
pub(crate) fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
}
unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
@@ -141,6 +198,67 @@ pub(crate) use self::Fallback_AD::*;
pub(crate) mod Fallback_AD {
#![allow(unused_variables)]
use libc::c_char;
use super::{CConcreteType, CTypeTreeRef, Context};
// TypeTree function fallbacks
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
unimplemented!()
}
pub(crate) fn set_inline(val: bool) {
unimplemented!()
}
@@ -169,3 +287,89 @@ pub(crate) mod Fallback_AD {
unimplemented!()
}
}
impl TypeTree {
pub(crate) fn new() -> TypeTree {
let inner = unsafe { EnzymeNewTypeTree() };
TypeTree { inner }
}
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
TypeTree { inner }
}
pub(crate) fn merge(self, other: Self) -> Self {
unsafe {
EnzymeMergeTypeTree(self.inner, other.inner);
}
drop(other);
self
}
#[must_use]
pub(crate) fn shift(
self,
layout: &str,
offset: isize,
max_size: isize,
add_offset: usize,
) -> Self {
let layout = std::ffi::CString::new(layout).unwrap();
unsafe {
EnzymeTypeTreeShiftIndiciesEq(
self.inner,
layout.as_ptr(),
offset as i64,
max_size as i64,
add_offset as u64,
);
}
self
}
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
unsafe {
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
}
}
}
impl Clone for TypeTree {
fn clone(&self) -> Self {
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
TypeTree { inner }
}
}
impl std::fmt::Display for TypeTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
match cstr.to_str() {
Ok(x) => write!(f, "{}", x)?,
Err(err) => write!(f, "could not parse: {}", err)?,
}
// delete C string pointer
unsafe {
EnzymeTypeTreeToStringFree(ptr);
}
Ok(())
}
}
impl std::fmt::Debug for TypeTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
<Self as std::fmt::Display>::fmt(self, f)
}
}
impl Drop for TypeTree {
fn drop(&mut self) {
unsafe { EnzymeFreeTypeTree(self.inner) }
}
}