Remove inlining for autodiff handling

This commit is contained in:
Marcelo Domínguez
2025-08-14 15:29:37 +00:00
parent 250d77e5d7
commit c9c1c17128
2 changed files with 21 additions and 19 deletions

View File

@@ -192,7 +192,6 @@ mod llvm_enzyme {
/// which becomes expanded to:
/// ```
/// #[rustc_autodiff]
/// #[inline(never)]
/// fn sin(x: &Box<f32>) -> f32 {
/// f32::sin(**x)
/// }
@@ -371,7 +370,7 @@ mod llvm_enzyme {
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
match (attr, item) {
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@@ -384,14 +383,16 @@ mod llvm_enzyme {
}
}
let mut has_inline_never = false;
// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::Item(iitem.clone())
}
@@ -399,8 +400,8 @@ mod llvm_enzyme {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
assoc_item.attrs.push(inline_never.clone());
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::AssocItem(assoc_item.clone(), i)
}
@@ -410,9 +411,8 @@ mod llvm_enzyme {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
{
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
}
_ => unreachable!("stmt kind checked previously"),
@@ -433,11 +433,19 @@ mod llvm_enzyme {
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
// If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
let mut d_attrs = thin_vec![d_attr];
if has_inline_never {
d_attrs.push(inline_never);
}
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
let d_fn = Box::new(ast::AssocItem {
attrs: thin_vec![d_attr],
attrs: d_attrs,
id: ast::DUMMY_NODE_ID,
span,
vis,
@@ -447,13 +455,13 @@ mod llvm_enzyme {
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Stmt(Box::new(ast::Stmt {