feat: propagate generics to generated function

This commit is contained in:
HaeNoe
2025-04-19 19:17:22 +02:00
parent 16c1c54a29
commit 56a0c7dfea

View File

@@ -73,10 +73,10 @@ mod llvm_enzyme {
} }
// Get information about the function the macro is applied to // Get information about the function the macro is applied to
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> { fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
match &iitem.kind { match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone())) Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
} }
_ => None, _ => None,
} }
@@ -210,16 +210,18 @@ mod llvm_enzyme {
} }
let dcx = ecx.sess.dcx(); let dcx = ecx.sess.dcx();
// first get information about the annotable item: // first get information about the annotable item: visibility, signature, name and generic
let Some((vis, sig, primal)) = (match &item { // parameters.
// these will be used to generate the differentiated version of the function
let Some((vis, sig, primal, generics)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Item(iitem) => extract_item_info(iitem),
Annotatable::Stmt(stmt) => match &stmt.kind { Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem), ast::StmtKind::Item(iitem) => extract_item_info(iitem),
_ => None, _ => None,
}, },
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
} }
_ => None, _ => None,
}, },
@@ -310,7 +312,7 @@ mod llvm_enzyme {
defaultness: ast::Defaultness::Final, defaultness: ast::Defaultness::Final,
sig: d_sig, sig: d_sig,
ident: first_ident(&meta_item_vec[0]), ident: first_ident(&meta_item_vec[0]),
generics: Generics::default(), generics,
contract: None, contract: None,
body: Some(d_body), body: Some(d_body),
define_opaque: None, define_opaque: None,