Add TypeTree metadata attachment for autodiff
- Add F128 support to TypeTree Kind enum - Implement TypeTree FFI bindings and conversion functions - Add typetree.rs module for metadata attachment to LLVM functions - Integrate TypeTree generation with autodiff intrinsic pipeline - Support scalar types: f32, f64, integers, f16, f128 - Attach enzyme_type attributes as LLVM string metadata for Enzyme Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
This commit is contained in:
@@ -3,9 +3,35 @@
|
||||
use libc::{c_char, c_uint};
|
||||
|
||||
use super::MetadataKindId;
|
||||
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
|
||||
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
|
||||
use crate::llvm::{Bool, Builder};
|
||||
|
||||
// TypeTree types
|
||||
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub(crate) struct EnzymeTypeTree {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
#[repr(u32)]
|
||||
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub(crate) enum CConcreteType {
|
||||
DT_Anything = 0,
|
||||
DT_Integer = 1,
|
||||
DT_Pointer = 2,
|
||||
DT_Half = 3,
|
||||
DT_Float = 4,
|
||||
DT_Double = 5,
|
||||
DT_Unknown = 6,
|
||||
}
|
||||
|
||||
pub(crate) struct TypeTree {
|
||||
pub(crate) inner: CTypeTreeRef,
|
||||
}
|
||||
|
||||
#[link(name = "llvm-wrapper", kind = "static")]
|
||||
unsafe extern "C" {
|
||||
// Enzyme
|
||||
@@ -68,10 +94,33 @@ pub(crate) mod Enzyme_AD {
|
||||
|
||||
use libc::c_void;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
|
||||
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
|
||||
}
|
||||
|
||||
// TypeTree functions
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
|
||||
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
|
||||
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
static mut EnzymePrintPerf: c_void;
|
||||
static mut EnzymePrintActivity: c_void;
|
||||
@@ -141,6 +190,57 @@ pub(crate) use self::Fallback_AD::*;
|
||||
pub(crate) mod Fallback_AD {
|
||||
#![allow(unused_variables)]
|
||||
|
||||
use libc::c_char;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
// TypeTree function fallbacks
|
||||
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) fn set_inline(val: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
@@ -169,3 +269,83 @@ pub(crate) mod Fallback_AD {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeTree {
|
||||
pub(crate) fn new() -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTree() };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
unsafe {
|
||||
EnzymeMergeTypeTree(self.inner, other.inner);
|
||||
}
|
||||
drop(other);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn shift(
|
||||
self,
|
||||
layout: &str,
|
||||
offset: isize,
|
||||
max_size: isize,
|
||||
add_offset: usize,
|
||||
) -> Self {
|
||||
let layout = std::ffi::CString::new(layout).unwrap();
|
||||
|
||||
unsafe {
|
||||
EnzymeTypeTreeShiftIndiciesEq(
|
||||
self.inner,
|
||||
layout.as_ptr(),
|
||||
offset as i64,
|
||||
max_size as i64,
|
||||
add_offset as u64,
|
||||
);
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TypeTree {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
|
||||
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
|
||||
match cstr.to_str() {
|
||||
Ok(x) => write!(f, "{}", x)?,
|
||||
Err(err) => write!(f, "could not parse: {}", err)?,
|
||||
}
|
||||
|
||||
// delete C string pointer
|
||||
unsafe {
|
||||
EnzymeTypeTreeToStringFree(ptr);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
<Self as std::fmt::Display>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TypeTree {
|
||||
fn drop(&mut self) {
|
||||
unsafe { EnzymeFreeTypeTree(self.inner) }
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user