feat: propagate generics to generated function

This commit is contained in:
HaeNoe
2025-04-19 19:17:22 +02:00
parent 16c1c54a29
commit 56a0c7dfea

View File

@@ -73,10 +73,10 @@ mod llvm_enzyme {
}
// Get information about the function the macro is applied to
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
_ => None,
}
@@ -210,16 +210,18 @@ mod llvm_enzyme {
}
let dcx = ecx.sess.dcx();
// first get information about the annotable item:
let Some((vis, sig, primal)) = (match &item {
// first get information about the annotable item: visibility, signature, name and generic
// parameters.
// these will be used to generate the differentiated version of the function
let Some((vis, sig, primal, generics)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
_ => None,
},
@@ -310,7 +312,7 @@ mod llvm_enzyme {
defaultness: ast::Defaultness::Final,
sig: d_sig,
ident: first_ident(&meta_item_vec[0]),
generics: Generics::default(),
generics,
contract: None,
body: Some(d_body),
define_opaque: None,