Add a procedural macro for expanding all function signatures
Introduce `libm_test::for_each_function`. which macro takes a callback macro and invokes it once per function signature. This provides an easier way of registering various tests and benchmarks without duplicating the function names and signatures each time.
This commit is contained in:
541
library/compiler-builtins/libm/crates/libm-macros/src/lib.rs
Normal file
541
library/compiler-builtins/libm/crates/libm-macros/src/lib.rs
Normal file
@@ -0,0 +1,541 @@
|
||||
mod parse;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use parse::{Invocation, StructuredInput};
|
||||
use proc_macro as pm;
|
||||
use proc_macro2::{self as pm2, Span};
|
||||
use quote::{ToTokens, quote};
|
||||
use syn::Ident;
|
||||
use syn::visit_mut::VisitMut;
|
||||
|
||||
const ALL_FUNCTIONS: &[(Signature, Option<Signature>, &[&str])] = &[
|
||||
(
|
||||
// `fn(f32) -> f32`
|
||||
Signature { args: &[Ty::F32], returns: &[Ty::F32] },
|
||||
None,
|
||||
&[
|
||||
"acosf", "acoshf", "asinf", "asinhf", "atanf", "atanhf", "cbrtf", "ceilf", "cosf",
|
||||
"coshf", "erff", "exp10f", "exp2f", "expf", "expm1f", "fabsf", "floorf", "j0f", "j1f",
|
||||
"lgammaf", "log10f", "log1pf", "log2f", "logf", "rintf", "roundf", "sinf", "sinhf",
|
||||
"sqrtf", "tanf", "tanhf", "tgammaf", "truncf",
|
||||
],
|
||||
),
|
||||
(
|
||||
// `(f64) -> f64`
|
||||
Signature { args: &[Ty::F64], returns: &[Ty::F64] },
|
||||
None,
|
||||
&[
|
||||
"acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "ceil", "cos", "cosh",
|
||||
"erf", "exp10", "exp2", "exp", "expm1", "fabs", "floor", "j0", "j1", "lgamma", "log10",
|
||||
"log1p", "log2", "log", "rint", "round", "sin", "sinh", "sqrt", "tan", "tanh",
|
||||
"tgamma", "trunc",
|
||||
],
|
||||
),
|
||||
(
|
||||
// `(f32, f32) -> f32`
|
||||
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32] },
|
||||
None,
|
||||
&[
|
||||
"atan2f",
|
||||
"copysignf",
|
||||
"fdimf",
|
||||
"fmaxf",
|
||||
"fminf",
|
||||
"fmodf",
|
||||
"hypotf",
|
||||
"nextafterf",
|
||||
"powf",
|
||||
"remainderf",
|
||||
],
|
||||
),
|
||||
(
|
||||
// `(f64, f64) -> f64`
|
||||
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64] },
|
||||
None,
|
||||
&[
|
||||
"atan2",
|
||||
"copysign",
|
||||
"fdim",
|
||||
"fmax",
|
||||
"fmin",
|
||||
"fmod",
|
||||
"hypot",
|
||||
"nextafter",
|
||||
"pow",
|
||||
"remainder",
|
||||
],
|
||||
),
|
||||
(
|
||||
// `(f32, f32, f32) -> f32`
|
||||
Signature { args: &[Ty::F32, Ty::F32, Ty::F32], returns: &[Ty::F32] },
|
||||
None,
|
||||
&["fmaf"],
|
||||
),
|
||||
(
|
||||
// `(f64, f64, f64) -> f64`
|
||||
Signature { args: &[Ty::F64, Ty::F64, Ty::F64], returns: &[Ty::F64] },
|
||||
None,
|
||||
&["fma"],
|
||||
),
|
||||
(
|
||||
// `(f32) -> i32`
|
||||
Signature { args: &[Ty::F32], returns: &[Ty::I32] },
|
||||
None,
|
||||
&["ilogbf"],
|
||||
),
|
||||
(
|
||||
// `(f64) -> i32`
|
||||
Signature { args: &[Ty::F64], returns: &[Ty::I32] },
|
||||
None,
|
||||
&["ilogb"],
|
||||
),
|
||||
(
|
||||
// `(i32, f32) -> f32`
|
||||
Signature { args: &[Ty::I32, Ty::F32], returns: &[Ty::F32] },
|
||||
None,
|
||||
&["jnf"],
|
||||
),
|
||||
(
|
||||
// `(i32, f64) -> f64`
|
||||
Signature { args: &[Ty::I32, Ty::F64], returns: &[Ty::F64] },
|
||||
None,
|
||||
&["jn"],
|
||||
),
|
||||
(
|
||||
// `(f32, i32) -> f32`
|
||||
Signature { args: &[Ty::F32, Ty::I32], returns: &[Ty::F32] },
|
||||
None,
|
||||
&["scalbnf", "ldexpf"],
|
||||
),
|
||||
(
|
||||
// `(f64, i64) -> f64`
|
||||
Signature { args: &[Ty::F64, Ty::I32], returns: &[Ty::F64] },
|
||||
None,
|
||||
&["scalbn", "ldexp"],
|
||||
),
|
||||
(
|
||||
// `(f32, &mut f32) -> f32` as `(f32) -> (f32, f32)`
|
||||
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
|
||||
Some(Signature { args: &[Ty::F32, Ty::MutF32], returns: &[Ty::F32] }),
|
||||
&["modff"],
|
||||
),
|
||||
(
|
||||
// `(f64, &mut f64) -> f64` as `(f64) -> (f64, f64)`
|
||||
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
|
||||
Some(Signature { args: &[Ty::F64, Ty::MutF64], returns: &[Ty::F64] }),
|
||||
&["modf"],
|
||||
),
|
||||
(
|
||||
// `(f32, &mut c_int) -> f32` as `(f32) -> (f32, i32)`
|
||||
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::I32] },
|
||||
Some(Signature { args: &[Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
|
||||
&["frexpf", "lgammaf_r"],
|
||||
),
|
||||
(
|
||||
// `(f64, &mut c_int) -> f64` as `(f64) -> (f64, i32)`
|
||||
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::I32] },
|
||||
Some(Signature { args: &[Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
|
||||
&["frexp", "lgamma_r"],
|
||||
),
|
||||
(
|
||||
// `(f32, f32, &mut c_int) -> f32` as `(f32, f32) -> (f32, i32)`
|
||||
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32, Ty::I32] },
|
||||
Some(Signature { args: &[Ty::F32, Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
|
||||
&["remquof"],
|
||||
),
|
||||
(
|
||||
// `(f64, f64, &mut c_int) -> f64` as `(f64, f64) -> (f64, i32)`
|
||||
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64, Ty::I32] },
|
||||
Some(Signature { args: &[Ty::F64, Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
|
||||
&["remquo"],
|
||||
),
|
||||
(
|
||||
// `(f32, &mut f32, &mut f32)` as `(f32) -> (f32, f32)`
|
||||
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
|
||||
Some(Signature { args: &[Ty::F32, Ty::MutF32, Ty::MutF32], returns: &[] }),
|
||||
&["sincosf"],
|
||||
),
|
||||
(
|
||||
// `(f64, &mut f64, &mut f64)` as `(f64) -> (f64, f64)`
|
||||
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
|
||||
Some(Signature { args: &[Ty::F64, Ty::MutF64, Ty::MutF64], returns: &[] }),
|
||||
&["sincos"],
|
||||
),
|
||||
];
|
||||
|
||||
/// A type used in a function signature.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Ty {
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
F128,
|
||||
I32,
|
||||
CInt,
|
||||
MutF16,
|
||||
MutF32,
|
||||
MutF64,
|
||||
MutF128,
|
||||
MutI32,
|
||||
MutCInt,
|
||||
}
|
||||
|
||||
impl ToTokens for Ty {
|
||||
fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
|
||||
let ts = match self {
|
||||
Ty::F16 => quote! { f16 },
|
||||
Ty::F32 => quote! { f32 },
|
||||
Ty::F64 => quote! { f64 },
|
||||
Ty::F128 => quote! { f128 },
|
||||
Ty::I32 => quote! { i32 },
|
||||
Ty::CInt => quote! { ::core::ffi::c_int },
|
||||
Ty::MutF16 => quote! { &mut f16 },
|
||||
Ty::MutF32 => quote! { &mut f32 },
|
||||
Ty::MutF64 => quote! { &mut f64 },
|
||||
Ty::MutF128 => quote! { &mut f128 },
|
||||
Ty::MutI32 => quote! { &mut i32 },
|
||||
Ty::MutCInt => quote! { &mut core::ffi::c_int },
|
||||
};
|
||||
|
||||
tokens.extend(ts);
|
||||
}
|
||||
}
|
||||
|
||||
/// Representation of e.g. `(f32, f32) -> f32`
|
||||
#[derive(Debug, Clone)]
|
||||
struct Signature {
|
||||
args: &'static [Ty],
|
||||
returns: &'static [Ty],
|
||||
}
|
||||
|
||||
/// Combined information about a function implementation.
|
||||
#[derive(Debug, Clone)]
|
||||
struct FunctionInfo {
|
||||
name: &'static str,
|
||||
/// Function signature for C implementations
|
||||
c_sig: Signature,
|
||||
/// Function signature for Rust implementations
|
||||
rust_sig: Signature,
|
||||
}
|
||||
|
||||
/// A flat representation of `ALL_FUNCTIONS`.
|
||||
static ALL_FUNCTIONS_FLAT: LazyLock<Vec<FunctionInfo>> = LazyLock::new(|| {
|
||||
let mut ret = Vec::new();
|
||||
|
||||
for (rust_sig, c_sig, names) in ALL_FUNCTIONS {
|
||||
for name in *names {
|
||||
let api = FunctionInfo {
|
||||
name,
|
||||
rust_sig: rust_sig.clone(),
|
||||
c_sig: c_sig.clone().unwrap_or_else(|| rust_sig.clone()),
|
||||
};
|
||||
ret.push(api);
|
||||
}
|
||||
}
|
||||
|
||||
ret.sort_by_key(|item| item.name);
|
||||
ret
|
||||
});
|
||||
|
||||
/// Do something for each function present in this crate.
|
||||
///
|
||||
/// Takes a callback macro and invokes it multiple times, once for each function that
|
||||
/// this crate exports. This makes it easy to create generic tests, benchmarks, or other checks
|
||||
/// and apply it to each symbol.
|
||||
///
|
||||
/// Additionally, the `extra` and `fn_extra` patterns can make use of magic identifiers:
|
||||
///
|
||||
/// - `MACRO_FN_NAME`: gets replaced with the name of the function on that invocation.
|
||||
/// - `MACRO_FN_NAME_NORMALIZED`: similar to the above, but removes sufixes so e.g. `sinf` becomes
|
||||
/// `sin`, `cosf128` becomes `cos`, etc.
|
||||
///
|
||||
/// Invoke as:
|
||||
///
|
||||
/// ```
|
||||
/// // Macro that is invoked once per function
|
||||
/// macro_rules! callback_macro {
|
||||
/// (
|
||||
/// // Name of that function
|
||||
/// fn_name: $fn_name:ident,
|
||||
/// // Function signature of the C version (e.g. `fn(f32, &mut f32) -> f32`)
|
||||
/// CFn: $CFn:ty,
|
||||
/// // A tuple representing the C version's arguments (e.g. `(f32, &mut f32)`)
|
||||
/// CArgs: $CArgs:ty,
|
||||
/// // The C version's return type (e.g. `f32`)
|
||||
/// CRet: $CRet:ty,
|
||||
/// // Function signature of the Rust version (e.g. `fn(f32) -> (f32, f32)`)
|
||||
/// RustFn: $RustFn:ty,
|
||||
/// // A tuple representing the Rust version's arguments (e.g. `(f32,)`)
|
||||
/// RustArgs: $RustArgs:ty,
|
||||
/// // The Rust version's return type (e.g. `(f32, f32)`)
|
||||
/// RustRet: $RustRet:ty,
|
||||
/// // Attributes for the current function, if any
|
||||
/// attrs: [$($meta:meta)*]
|
||||
/// // Extra tokens passed directly (if any)
|
||||
/// extra: [$extra:ident],
|
||||
/// // Extra function-tokens passed directly (if any)
|
||||
/// fn_extra: $fn_extra:expr,
|
||||
/// ) => { };
|
||||
/// }
|
||||
///
|
||||
/// libm_macros::for_each_function! {
|
||||
/// // The macro to invoke as a callback
|
||||
/// callback: callback_macro,
|
||||
/// // Functions to skip, i.e. `callback` shouldn't be called at all for these.
|
||||
/// //
|
||||
/// // This is an optional field.
|
||||
/// skip: [sin, cos],
|
||||
/// // Attributes passed as `attrs` for specific functions. For example, here the invocation
|
||||
/// // with `sinf` and that with `cosf` will both get `meta1` and `meta2`, but no others will.
|
||||
/// //
|
||||
/// // This is an optional field.
|
||||
/// attributes: [
|
||||
/// #[meta1]
|
||||
/// #[meta2]
|
||||
/// [sinf, cosf],
|
||||
/// ],
|
||||
/// // Any tokens that should be passed directly to all invocations of the callback. This can
|
||||
/// // be used to pass local variables or other things the macro needs access to.
|
||||
/// //
|
||||
/// // This is an optional field.
|
||||
/// extra: [foo],
|
||||
/// // Similar to `extra`, but allow providing a pattern for only specific functions. Uses
|
||||
/// // a simplified match-like syntax.
|
||||
/// fn_extra: match MACRO_FN_NAME {
|
||||
/// hypot | hypotf => |x| x.hypot(),
|
||||
/// _ => |x| x,
|
||||
/// },
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro]
|
||||
pub fn for_each_function(tokens: pm::TokenStream) -> pm::TokenStream {
|
||||
let input = syn::parse_macro_input!(tokens as Invocation);
|
||||
|
||||
let res = StructuredInput::from_fields(input)
|
||||
.and_then(|s_in| validate(&s_in).map(|fn_list| (s_in, fn_list)))
|
||||
.and_then(|(s_in, fn_list)| expand(s_in, &fn_list));
|
||||
|
||||
match res {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.into_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check for any input that is structurally correct but has other problems.
|
||||
///
|
||||
/// Returns the list of function names that we should expand for.
|
||||
fn validate(input: &StructuredInput) -> syn::Result<Vec<&'static FunctionInfo>> {
|
||||
// Collect lists of all functions that are provied as macro inputs in various fields (only,
|
||||
// skip, attributes).
|
||||
let attr_mentions = input
|
||||
.attributes
|
||||
.iter()
|
||||
.flat_map(|map_list| map_list.iter())
|
||||
.flat_map(|attr_map| attr_map.names.iter());
|
||||
let only_mentions = input.only.iter().flat_map(|only_list| only_list.iter());
|
||||
let fn_extra_mentions =
|
||||
input.fn_extra.iter().flat_map(|v| v.keys()).filter(|name| *name != "_");
|
||||
let all_mentioned_fns =
|
||||
input.skip.iter().chain(only_mentions).chain(attr_mentions).chain(fn_extra_mentions);
|
||||
|
||||
// Make sure that every function mentioned is a real function
|
||||
for mentioned in all_mentioned_fns {
|
||||
if !ALL_FUNCTIONS_FLAT.iter().any(|func| mentioned == func.name) {
|
||||
let e = syn::Error::new(
|
||||
mentioned.span(),
|
||||
format!("unrecognized function name `{mentioned}`"),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
if !input.skip.is_empty() && input.only.is_some() {
|
||||
let e = syn::Error::new(
|
||||
input.only_span.unwrap(),
|
||||
format!("only one of `skip` or `only` may be specified"),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Construct a list of what we intend to expand
|
||||
let mut fn_list = Vec::new();
|
||||
for func in ALL_FUNCTIONS_FLAT.iter() {
|
||||
let fn_name = func.name;
|
||||
// If we have an `only` list and it does _not_ contain this function name, skip it
|
||||
if input.only.as_ref().is_some_and(|only| !only.iter().any(|o| o == fn_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If there is a `skip` list that contains this function name, skip it
|
||||
if input.skip.iter().any(|s| s == fn_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Run everything else
|
||||
fn_list.push(func);
|
||||
}
|
||||
|
||||
if let Some(map) = &input.fn_extra {
|
||||
if !map.keys().any(|key| key == "_") {
|
||||
// No default provided; make sure every expected function is covered
|
||||
let mut fns_not_covered = Vec::new();
|
||||
for func in &fn_list {
|
||||
if !map.keys().any(|key| key == func.name) {
|
||||
// `name` was not mentioned in the `match` statement
|
||||
fns_not_covered.push(func);
|
||||
}
|
||||
}
|
||||
|
||||
if !fns_not_covered.is_empty() {
|
||||
let e = syn::Error::new(
|
||||
input.fn_extra_span.unwrap(),
|
||||
format!(
|
||||
"`fn_extra`: no default `_` pattern specified and the following \
|
||||
patterns are not covered: {fns_not_covered:#?}"
|
||||
),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(fn_list)
|
||||
}
|
||||
|
||||
/// Expand our structured macro input into invocations of the callback macro.
|
||||
fn expand(input: StructuredInput, fn_list: &[&FunctionInfo]) -> syn::Result<pm2::TokenStream> {
|
||||
let mut out = pm2::TokenStream::new();
|
||||
let default_ident = Ident::new("_", Span::call_site());
|
||||
let callback = input.callback;
|
||||
|
||||
for func in fn_list {
|
||||
let fn_name = Ident::new(func.name, Span::call_site());
|
||||
|
||||
// Prepare attributes in an `attrs: ...` field
|
||||
let meta_field = match &input.attributes {
|
||||
Some(attrs) => {
|
||||
let meta = attrs
|
||||
.iter()
|
||||
.filter(|map| map.names.contains(&fn_name))
|
||||
.flat_map(|map| &map.meta);
|
||||
quote! { attrs: [ #( #meta )* ] }
|
||||
}
|
||||
None => pm2::TokenStream::new(),
|
||||
};
|
||||
|
||||
// Prepare extra in an `extra: ...` field, running the replacer
|
||||
let extra_field = match input.extra.clone() {
|
||||
Some(mut extra) => {
|
||||
let mut v = MacroReplace::new(func.name);
|
||||
v.visit_expr_mut(&mut extra);
|
||||
v.finish()?;
|
||||
|
||||
quote! { extra: #extra, }
|
||||
}
|
||||
None => pm2::TokenStream::new(),
|
||||
};
|
||||
|
||||
// Prepare function-specific extra in a `fn_extra: ...` field, running the replacer
|
||||
let fn_extra_field = match input.fn_extra {
|
||||
Some(ref map) => {
|
||||
let mut fn_extra =
|
||||
map.get(&fn_name).or_else(|| map.get(&default_ident)).unwrap().clone();
|
||||
|
||||
let mut v = MacroReplace::new(func.name);
|
||||
v.visit_expr_mut(&mut fn_extra);
|
||||
v.finish()?;
|
||||
|
||||
quote! { fn_extra: #fn_extra, }
|
||||
}
|
||||
None => pm2::TokenStream::new(),
|
||||
};
|
||||
|
||||
let c_args = &func.c_sig.args;
|
||||
let c_ret = &func.c_sig.returns;
|
||||
let rust_args = &func.rust_sig.args;
|
||||
let rust_ret = &func.rust_sig.returns;
|
||||
|
||||
let new = quote! {
|
||||
#callback! {
|
||||
fn_name: #fn_name,
|
||||
CFn: fn( #(#c_args),* ,) -> ( #(#c_ret),* ),
|
||||
CArgs: ( #(#c_args),* ,),
|
||||
CRet: ( #(#c_ret),* ),
|
||||
RustFn: fn( #(#rust_args),* ,) -> ( #(#rust_ret),* ),
|
||||
RustArgs: ( #(#rust_args),* ,),
|
||||
RustRet: ( #(#rust_ret),* ),
|
||||
#meta_field
|
||||
#extra_field
|
||||
#fn_extra_field
|
||||
}
|
||||
};
|
||||
|
||||
out.extend(new);
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Visitor to replace "magic" identifiers that we allow: `MACRO_FN_NAME` and
|
||||
/// `MACRO_FN_NAME_NORMALIZED`.
|
||||
struct MacroReplace {
|
||||
fn_name: &'static str,
|
||||
/// Remove the trailing `f` or `f128` to make
|
||||
norm_name: String,
|
||||
error: Option<syn::Error>,
|
||||
}
|
||||
|
||||
impl MacroReplace {
|
||||
fn new(name: &'static str) -> Self {
|
||||
// Keep this in sync with `libm_test::canonical_name`
|
||||
let known_mappings = &[
|
||||
("erff", "erf"),
|
||||
("erf", "erf"),
|
||||
("lgammaf_r", "lgamma_r"),
|
||||
("modff", "modf"),
|
||||
("modf", "modf"),
|
||||
];
|
||||
|
||||
let norm_name = match known_mappings.iter().find(|known| known.0 == name) {
|
||||
Some(found) => found.1,
|
||||
None => name
|
||||
.strip_suffix("f")
|
||||
.or_else(|| name.strip_suffix("f16"))
|
||||
.or_else(|| name.strip_suffix("f128"))
|
||||
.unwrap_or(name),
|
||||
};
|
||||
|
||||
Self { fn_name: name, norm_name: norm_name.to_owned(), error: None }
|
||||
}
|
||||
|
||||
fn finish(self) -> syn::Result<()> {
|
||||
match self.error {
|
||||
Some(e) => Err(e),
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident_inner(&mut self, i: &mut Ident) {
|
||||
let s = i.to_string();
|
||||
if !s.starts_with("MACRO") || self.error.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
match s.as_str() {
|
||||
"MACRO_FN_NAME" => *i = Ident::new(self.fn_name, i.span()),
|
||||
"MACRO_FN_NAME_NORMALIZED" => *i = Ident::new(&self.norm_name, i.span()),
|
||||
_ => {
|
||||
self.error =
|
||||
Some(syn::Error::new(i.span(), format!("unrecognized meta expression `{s}`")));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VisitMut for MacroReplace {
|
||||
fn visit_ident_mut(&mut self, i: &mut Ident) {
|
||||
self.visit_ident_inner(i);
|
||||
syn::visit_mut::visit_ident_mut(self, i);
|
||||
}
|
||||
}
|
||||
236
library/compiler-builtins/libm/crates/libm-macros/src/parse.rs
Normal file
236
library/compiler-builtins/libm/crates/libm-macros/src/parse.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use proc_macro2::Span;
|
||||
use quote::ToTokens;
|
||||
use syn::parse::{Parse, ParseStream, Parser};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::token::Comma;
|
||||
use syn::{Arm, Attribute, Expr, ExprMatch, Ident, Meta, Token, bracketed};
|
||||
|
||||
/// The input to our macro; just a list of `field: value` items.
|
||||
#[derive(Debug)]
|
||||
pub struct Invocation {
|
||||
fields: Punctuated<Mapping, Comma>,
|
||||
}
|
||||
|
||||
impl Parse for Invocation {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
Ok(Self { fields: input.parse_terminated(Mapping::parse, Token![,])? })
|
||||
}
|
||||
}
|
||||
|
||||
/// A `key: expression` mapping with nothing else. Basically a simplified `syn::Field`.
|
||||
#[derive(Debug)]
|
||||
struct Mapping {
|
||||
name: Ident,
|
||||
_sep: Token![:],
|
||||
expr: Expr,
|
||||
}
|
||||
|
||||
impl Parse for Mapping {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
Ok(Self { name: input.parse()?, _sep: input.parse()?, expr: input.parse()? })
|
||||
}
|
||||
}
|
||||
|
||||
/// The input provided to our proc macro, after parsing into the form we expect.
|
||||
#[derive(Debug)]
|
||||
pub struct StructuredInput {
|
||||
/// Macro to invoke once per function
|
||||
pub callback: Ident,
|
||||
/// Skip these functions
|
||||
pub skip: Vec<Ident>,
|
||||
/// Invoke only for these functions
|
||||
pub only: Option<Vec<Ident>>,
|
||||
/// Attributes that get applied to specific functions
|
||||
pub attributes: Option<Vec<AttributeMap>>,
|
||||
/// Extra expressions to pass to all invocations of the macro
|
||||
pub extra: Option<Expr>,
|
||||
/// Per-function extra expressions to pass to the macro
|
||||
pub fn_extra: Option<BTreeMap<Ident, Expr>>,
|
||||
// For diagnostics
|
||||
pub only_span: Option<Span>,
|
||||
pub fn_extra_span: Option<Span>,
|
||||
}
|
||||
|
||||
impl StructuredInput {
|
||||
pub fn from_fields(input: Invocation) -> syn::Result<Self> {
|
||||
let mut map: Vec<_> = input.fields.into_iter().collect();
|
||||
let cb_expr = expect_field(&mut map, "callback")?;
|
||||
let skip_expr = expect_field(&mut map, "skip").ok();
|
||||
let only_expr = expect_field(&mut map, "only").ok();
|
||||
let attr_expr = expect_field(&mut map, "attributes").ok();
|
||||
let extra = expect_field(&mut map, "extra").ok();
|
||||
let fn_extra = expect_field(&mut map, "fn_extra").ok();
|
||||
|
||||
if !map.is_empty() {
|
||||
Err(syn::Error::new(
|
||||
map.first().unwrap().name.span(),
|
||||
format!("unexpected fields {map:?}"),
|
||||
))?;
|
||||
}
|
||||
|
||||
let skip = match skip_expr {
|
||||
Some(expr) => Parser::parse2(parse_ident_array, expr.into_token_stream())?,
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
let only_span = only_expr.as_ref().map(|expr| expr.span());
|
||||
let only = match only_expr {
|
||||
Some(expr) => Some(Parser::parse2(parse_ident_array, expr.into_token_stream())?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let attributes = match attr_expr {
|
||||
Some(expr) => {
|
||||
let mut attributes = Vec::new();
|
||||
let attr_exprs = Parser::parse2(parse_expr_array, expr.into_token_stream())?;
|
||||
|
||||
for attr in attr_exprs {
|
||||
attributes.push(syn::parse2(attr.into_token_stream())?);
|
||||
}
|
||||
Some(attributes)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let fn_extra_span = fn_extra.as_ref().map(|expr| expr.span());
|
||||
let fn_extra = match fn_extra {
|
||||
Some(expr) => Some(extract_fn_extra_field(expr)?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
callback: expect_ident(cb_expr)?,
|
||||
skip,
|
||||
only,
|
||||
only_span,
|
||||
attributes,
|
||||
extra,
|
||||
fn_extra,
|
||||
fn_extra_span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_fn_extra_field(expr: Expr) -> syn::Result<BTreeMap<Ident, Expr>> {
|
||||
let Expr::Match(mexpr) = expr else {
|
||||
let e = syn::Error::new(expr.span(), "`fn_extra` expects a match expression");
|
||||
return Err(e);
|
||||
};
|
||||
|
||||
let ExprMatch { attrs, match_token: _, expr, brace_token: _, arms } = mexpr;
|
||||
|
||||
expect_empty_attrs(&attrs)?;
|
||||
|
||||
let match_on = expect_ident(*expr)?;
|
||||
if match_on != "MACRO_FN_NAME" {
|
||||
let e = syn::Error::new(match_on.span(), "only allowed to match on `MACRO_FN_NAME`");
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
let mut res = BTreeMap::new();
|
||||
|
||||
for arm in arms {
|
||||
let Arm { attrs, pat, guard, fat_arrow_token: _, body, comma: _ } = arm;
|
||||
|
||||
expect_empty_attrs(&attrs)?;
|
||||
|
||||
let keys = match pat {
|
||||
syn::Pat::Wild(w) => vec![Ident::new("_", w.span())],
|
||||
_ => Parser::parse2(parse_ident_pat, pat.into_token_stream())?,
|
||||
};
|
||||
|
||||
if let Some(guard) = guard {
|
||||
let e = syn::Error::new(guard.0.span(), "no guards allowed in this position");
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
for key in keys {
|
||||
let inserted = res.insert(key.clone(), *body.clone());
|
||||
if inserted.is_some() {
|
||||
let e = syn::Error::new(key.span(), format!("key `{key}` specified twice"));
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn expect_empty_attrs(attrs: &[Attribute]) -> syn::Result<()> {
|
||||
if attrs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let e =
|
||||
syn::Error::new(attrs.first().unwrap().span(), "no attributes allowed in this position");
|
||||
Err(e)
|
||||
}
|
||||
|
||||
/// Extract a named field from a map, raising an error if it doesn't exist.
|
||||
fn expect_field(v: &mut Vec<Mapping>, name: &str) -> syn::Result<Expr> {
|
||||
let pos = v.iter().position(|v| v.name == name).ok_or_else(|| {
|
||||
syn::Error::new(Span::call_site(), format!("missing expected field `{name}`"))
|
||||
})?;
|
||||
|
||||
Ok(v.remove(pos).expr)
|
||||
}
|
||||
|
||||
/// Coerce an expression into a simple identifier.
|
||||
fn expect_ident(expr: Expr) -> syn::Result<Ident> {
|
||||
syn::parse2(expr.into_token_stream())
|
||||
}
|
||||
|
||||
/// Parse an array of expressions.
|
||||
fn parse_expr_array(input: ParseStream) -> syn::Result<Vec<Expr>> {
|
||||
let content;
|
||||
let _ = bracketed!(content in input);
|
||||
let fields = content.parse_terminated(Expr::parse, Token![,])?;
|
||||
Ok(fields.into_iter().collect())
|
||||
}
|
||||
|
||||
/// Parse an array of idents, e.g. `[foo, bar, baz]`.
|
||||
fn parse_ident_array(input: ParseStream) -> syn::Result<Vec<Ident>> {
|
||||
let content;
|
||||
let _ = bracketed!(content in input);
|
||||
let fields = content.parse_terminated(Ident::parse, Token![,])?;
|
||||
Ok(fields.into_iter().collect())
|
||||
}
|
||||
|
||||
/// Parse an pattern of idents, specifically `(foo | bar | baz)`.
|
||||
fn parse_ident_pat(input: ParseStream) -> syn::Result<Vec<Ident>> {
|
||||
if !input.peek2(Token![|]) {
|
||||
return Ok(vec![input.parse()?]);
|
||||
}
|
||||
|
||||
let fields = Punctuated::<Ident, Token![|]>::parse_separated_nonempty(input)?;
|
||||
Ok(fields.into_iter().collect())
|
||||
}
|
||||
|
||||
/// A mapping of attributes to identifiers (just a simplified `Expr`).
|
||||
///
|
||||
/// Expressed as:
|
||||
///
|
||||
/// ```ignore
|
||||
/// #[meta1]
|
||||
/// #[meta2]
|
||||
/// [foo, bar, baz]
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct AttributeMap {
|
||||
pub meta: Vec<Meta>,
|
||||
pub names: Vec<Ident>,
|
||||
}
|
||||
|
||||
impl Parse for AttributeMap {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
let attrs = input.call(Attribute::parse_outer)?;
|
||||
|
||||
Ok(Self {
|
||||
meta: attrs.into_iter().map(|a| a.meta).collect(),
|
||||
names: parse_ident_array(input)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user