Move the macro's input function list to a new module shared

This will enable us to `include!` the file to access these types in
`libm-test`, rather than somehow reproducing the types as part of the
macro. Ideally `libm-test` would just `use` the types from `libm-macros`
but proc macro crates cannot currently export anything else.

This also adjusts naming to closer match the scheme described in
`libm_test::op`.
This commit is contained in:
Trevor Gross
2024-12-22 11:22:02 +00:00
parent 49aa452fd9
commit 1069346b6d
3 changed files with 320 additions and 263 deletions

View File

@@ -5,7 +5,7 @@ use quote::quote;
use syn::spanned::Spanned;
use syn::{Fields, ItemEnum, Variant};
use crate::{ALL_FUNCTIONS_FLAT, base_name};
use crate::{ALL_OPERATIONS, base_name};
/// Implement `#[function_enum]`, see documentation in `lib.rs`.
pub fn function_enum(
@@ -33,7 +33,7 @@ pub fn function_enum(
let mut as_str_arms = Vec::new();
let mut base_arms = Vec::new();
for func in ALL_FUNCTIONS_FLAT.iter() {
for func in ALL_OPERATIONS.iter() {
let fn_name = func.name;
let ident = Ident::new(&fn_name.to_upper_camel_case(), Span::call_site());
let bname_ident = Ident::new(&base_name(fn_name).to_upper_camel_case(), Span::call_site());
@@ -85,8 +85,7 @@ pub fn base_name_enum(
return Err(syn::Error::new(sp.span(), "no attributes expected"));
}
let mut base_names: Vec<_> =
ALL_FUNCTIONS_FLAT.iter().map(|func| base_name(func.name)).collect();
let mut base_names: Vec<_> = ALL_OPERATIONS.iter().map(|func| base_name(func.name)).collect();
base_names.sort_unstable();
base_names.dedup();

View File

@@ -1,270 +1,18 @@
mod enums;
mod parse;
use std::sync::LazyLock;
mod shared;
use parse::{Invocation, StructuredInput};
use proc_macro as pm;
use proc_macro2::{self as pm2, Span};
use quote::{ToTokens, quote};
pub(crate) use shared::{ALL_OPERATIONS, FloatTy, MathOpInfo, Ty};
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::{Ident, ItemEnum};
const ALL_FUNCTIONS: &[(Ty, Signature, Option<Signature>, &[&str])] = &[
(
// `fn(f32) -> f32`
Ty::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32] },
None,
&[
"acosf", "acoshf", "asinf", "asinhf", "atanf", "atanhf", "cbrtf", "ceilf", "cosf",
"coshf", "erff", "exp10f", "exp2f", "expf", "expm1f", "fabsf", "floorf", "j0f", "j1f",
"lgammaf", "log10f", "log1pf", "log2f", "logf", "rintf", "roundf", "sinf", "sinhf",
"sqrtf", "tanf", "tanhf", "tgammaf", "truncf",
],
),
(
// `(f64) -> f64`
Ty::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64] },
None,
&[
"acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "ceil", "cos", "cosh",
"erf", "exp10", "exp2", "exp", "expm1", "fabs", "floor", "j0", "j1", "lgamma", "log10",
"log1p", "log2", "log", "rint", "round", "sin", "sinh", "sqrt", "tan", "tanh",
"tgamma", "trunc",
],
),
(
// `(f32, f32) -> f32`
Ty::F32,
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32] },
None,
&[
"atan2f",
"copysignf",
"fdimf",
"fmaxf",
"fminf",
"fmodf",
"hypotf",
"nextafterf",
"powf",
"remainderf",
],
),
(
// `(f64, f64) -> f64`
Ty::F64,
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64] },
None,
&[
"atan2",
"copysign",
"fdim",
"fmax",
"fmin",
"fmod",
"hypot",
"nextafter",
"pow",
"remainder",
],
),
(
// `(f32, f32, f32) -> f32`
Ty::F32,
Signature { args: &[Ty::F32, Ty::F32, Ty::F32], returns: &[Ty::F32] },
None,
&["fmaf"],
),
(
// `(f64, f64, f64) -> f64`
Ty::F64,
Signature { args: &[Ty::F64, Ty::F64, Ty::F64], returns: &[Ty::F64] },
None,
&["fma"],
),
(
// `(f32) -> i32`
Ty::F32,
Signature { args: &[Ty::F32], returns: &[Ty::I32] },
None,
&["ilogbf"],
),
(
// `(f64) -> i32`
Ty::F64,
Signature { args: &[Ty::F64], returns: &[Ty::I32] },
None,
&["ilogb"],
),
(
// `(i32, f32) -> f32`
Ty::F32,
Signature { args: &[Ty::I32, Ty::F32], returns: &[Ty::F32] },
None,
&["jnf"],
),
(
// `(i32, f64) -> f64`
Ty::F64,
Signature { args: &[Ty::I32, Ty::F64], returns: &[Ty::F64] },
None,
&["jn"],
),
(
// `(f32, i32) -> f32`
Ty::F32,
Signature { args: &[Ty::F32, Ty::I32], returns: &[Ty::F32] },
None,
&["scalbnf", "ldexpf"],
),
(
// `(f64, i64) -> f64`
Ty::F64,
Signature { args: &[Ty::F64, Ty::I32], returns: &[Ty::F64] },
None,
&["scalbn", "ldexp"],
),
(
// `(f32, &mut f32) -> f32` as `(f32) -> (f32, f32)`
Ty::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
Some(Signature { args: &[Ty::F32, Ty::MutF32], returns: &[Ty::F32] }),
&["modff"],
),
(
// `(f64, &mut f64) -> f64` as `(f64) -> (f64, f64)`
Ty::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
Some(Signature { args: &[Ty::F64, Ty::MutF64], returns: &[Ty::F64] }),
&["modf"],
),
(
// `(f32, &mut c_int) -> f32` as `(f32) -> (f32, i32)`
Ty::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::I32] },
Some(Signature { args: &[Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
&["frexpf", "lgammaf_r"],
),
(
// `(f64, &mut c_int) -> f64` as `(f64) -> (f64, i32)`
Ty::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::I32] },
Some(Signature { args: &[Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
&["frexp", "lgamma_r"],
),
(
// `(f32, f32, &mut c_int) -> f32` as `(f32, f32) -> (f32, i32)`
Ty::F32,
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32, Ty::I32] },
Some(Signature { args: &[Ty::F32, Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
&["remquof"],
),
(
// `(f64, f64, &mut c_int) -> f64` as `(f64, f64) -> (f64, i32)`
Ty::F64,
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64, Ty::I32] },
Some(Signature { args: &[Ty::F64, Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
&["remquo"],
),
(
// `(f32, &mut f32, &mut f32)` as `(f32) -> (f32, f32)`
Ty::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
Some(Signature { args: &[Ty::F32, Ty::MutF32, Ty::MutF32], returns: &[] }),
&["sincosf"],
),
(
// `(f64, &mut f64, &mut f64)` as `(f64) -> (f64, f64)`
Ty::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
Some(Signature { args: &[Ty::F64, Ty::MutF64, Ty::MutF64], returns: &[] }),
&["sincos"],
),
];
const KNOWN_TYPES: &[&str] = &["FTy", "CFn", "CArgs", "CRet", "RustFn", "RustArgs", "RustRet"];
/// A type used in a function signature.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum Ty {
F16,
F32,
F64,
F128,
I32,
CInt,
MutF16,
MutF32,
MutF64,
MutF128,
MutI32,
MutCInt,
}
impl ToTokens for Ty {
fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
let ts = match self {
Ty::F16 => quote! { f16 },
Ty::F32 => quote! { f32 },
Ty::F64 => quote! { f64 },
Ty::F128 => quote! { f128 },
Ty::I32 => quote! { i32 },
Ty::CInt => quote! { ::core::ffi::c_int },
Ty::MutF16 => quote! { &'a mut f16 },
Ty::MutF32 => quote! { &'a mut f32 },
Ty::MutF64 => quote! { &'a mut f64 },
Ty::MutF128 => quote! { &'a mut f128 },
Ty::MutI32 => quote! { &'a mut i32 },
Ty::MutCInt => quote! { &'a mut core::ffi::c_int },
};
tokens.extend(ts);
}
}
/// Representation of e.g. `(f32, f32) -> f32`
#[derive(Debug, Clone)]
struct Signature {
args: &'static [Ty],
returns: &'static [Ty],
}
/// Combined information about a function implementation.
#[derive(Debug, Clone)]
struct FunctionInfo {
name: &'static str,
base_fty: Ty,
/// Function signature for C implementations
c_sig: Signature,
/// Function signature for Rust implementations
rust_sig: Signature,
}
/// A flat representation of `ALL_FUNCTIONS`.
static ALL_FUNCTIONS_FLAT: LazyLock<Vec<FunctionInfo>> = LazyLock::new(|| {
let mut ret = Vec::new();
for (base_fty, rust_sig, c_sig, names) in ALL_FUNCTIONS {
for name in *names {
let api = FunctionInfo {
name,
base_fty: *base_fty,
rust_sig: rust_sig.clone(),
c_sig: c_sig.clone().unwrap_or_else(|| rust_sig.clone()),
};
ret.push(api);
}
}
ret.sort_by_key(|item| item.name);
ret
});
/// Populate an enum with a variant representing function. Names are in upper camel case.
///
/// Applied to an empty enum. Expects one attribute `#[function_enum(BaseName)]` that provides
@@ -382,7 +130,7 @@ pub fn for_each_function(tokens: pm::TokenStream) -> pm::TokenStream {
/// Check for any input that is structurally correct but has other problems.
///
/// Returns the list of function names that we should expand for.
fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static FunctionInfo>> {
fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static MathOpInfo>> {
// Collect lists of all functions that are provied as macro inputs in various fields (only,
// skip, attributes).
let attr_mentions = input
@@ -398,7 +146,7 @@ fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static FunctionInf
// Make sure that every function mentioned is a real function
for mentioned in all_mentioned_fns {
if !ALL_FUNCTIONS_FLAT.iter().any(|func| mentioned == func.name) {
if !ALL_OPERATIONS.iter().any(|func| mentioned == func.name) {
let e = syn::Error::new(
mentioned.span(),
format!("unrecognized function name `{mentioned}`"),
@@ -417,7 +165,7 @@ fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static FunctionInf
// Construct a list of what we intend to expand
let mut fn_list = Vec::new();
for func in ALL_FUNCTIONS_FLAT.iter() {
for func in ALL_OPERATIONS.iter() {
let fn_name = func.name;
// If we have an `only` list and it does _not_ contain this function name, skip it
if input.only.as_ref().is_some_and(|only| !only.iter().any(|o| o == fn_name)) {
@@ -498,7 +246,7 @@ fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static FunctionInf
}
/// Expand our structured macro input into invocations of the callback macro.
fn expand(input: StructuredInput, fn_list: &[&FunctionInfo]) -> syn::Result<pm2::TokenStream> {
fn expand(input: StructuredInput, fn_list: &[&MathOpInfo]) -> syn::Result<pm2::TokenStream> {
let mut out = pm2::TokenStream::new();
let default_ident = Ident::new("_", Span::call_site());
let callback = input.callback;
@@ -545,7 +293,7 @@ fn expand(input: StructuredInput, fn_list: &[&FunctionInfo]) -> syn::Result<pm2:
None => pm2::TokenStream::new(),
};
let base_fty = func.base_fty;
let base_fty = func.float_ty;
let c_args = &func.c_sig.args;
let c_ret = &func.c_sig.returns;
let rust_args = &func.rust_sig.args;
@@ -648,3 +396,36 @@ fn base_name(name: &str) -> &str {
.unwrap_or(name),
}
}
impl ToTokens for Ty {
fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
let ts = match self {
Ty::F16 => quote! { f16 },
Ty::F32 => quote! { f32 },
Ty::F64 => quote! { f64 },
Ty::F128 => quote! { f128 },
Ty::I32 => quote! { i32 },
Ty::CInt => quote! { ::core::ffi::c_int },
Ty::MutF16 => quote! { &'a mut f16 },
Ty::MutF32 => quote! { &'a mut f32 },
Ty::MutF64 => quote! { &'a mut f64 },
Ty::MutF128 => quote! { &'a mut f128 },
Ty::MutI32 => quote! { &'a mut i32 },
Ty::MutCInt => quote! { &'a mut core::ffi::c_int },
};
tokens.extend(ts);
}
}
impl ToTokens for FloatTy {
fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
let ts = match self {
FloatTy::F16 => quote! { f16 },
FloatTy::F32 => quote! { f32 },
FloatTy::F64 => quote! { f64 },
FloatTy::F128 => quote! { f128 },
};
tokens.extend(ts);
}
}

View File

@@ -0,0 +1,277 @@
/* List of all functions that is shared between `libm-macros` and `libm-test`. */
use std::fmt;
use std::sync::LazyLock;
const ALL_OPERATIONS_NESTED: &[(FloatTy, Signature, Option<Signature>, &[&str])] = &[
(
// `fn(f32) -> f32`
FloatTy::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32] },
None,
&[
"acosf", "acoshf", "asinf", "asinhf", "atanf", "atanhf", "cbrtf", "ceilf", "cosf",
"coshf", "erff", "exp10f", "exp2f", "expf", "expm1f", "fabsf", "floorf", "j0f", "j1f",
"lgammaf", "log10f", "log1pf", "log2f", "logf", "rintf", "roundf", "sinf", "sinhf",
"sqrtf", "tanf", "tanhf", "tgammaf", "truncf",
],
),
(
// `(f64) -> f64`
FloatTy::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64] },
None,
&[
"acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "ceil", "cos", "cosh",
"erf", "exp10", "exp2", "exp", "expm1", "fabs", "floor", "j0", "j1", "lgamma", "log10",
"log1p", "log2", "log", "rint", "round", "sin", "sinh", "sqrt", "tan", "tanh",
"tgamma", "trunc",
],
),
(
// `(f32, f32) -> f32`
FloatTy::F32,
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32] },
None,
&[
"atan2f",
"copysignf",
"fdimf",
"fmaxf",
"fminf",
"fmodf",
"hypotf",
"nextafterf",
"powf",
"remainderf",
],
),
(
// `(f64, f64) -> f64`
FloatTy::F64,
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64] },
None,
&[
"atan2",
"copysign",
"fdim",
"fmax",
"fmin",
"fmod",
"hypot",
"nextafter",
"pow",
"remainder",
],
),
(
// `(f32, f32, f32) -> f32`
FloatTy::F32,
Signature { args: &[Ty::F32, Ty::F32, Ty::F32], returns: &[Ty::F32] },
None,
&["fmaf"],
),
(
// `(f64, f64, f64) -> f64`
FloatTy::F64,
Signature { args: &[Ty::F64, Ty::F64, Ty::F64], returns: &[Ty::F64] },
None,
&["fma"],
),
(
// `(f32) -> i32`
FloatTy::F32,
Signature { args: &[Ty::F32], returns: &[Ty::I32] },
None,
&["ilogbf"],
),
(
// `(f64) -> i32`
FloatTy::F64,
Signature { args: &[Ty::F64], returns: &[Ty::I32] },
None,
&["ilogb"],
),
(
// `(i32, f32) -> f32`
FloatTy::F32,
Signature { args: &[Ty::I32, Ty::F32], returns: &[Ty::F32] },
None,
&["jnf"],
),
(
// `(i32, f64) -> f64`
FloatTy::F64,
Signature { args: &[Ty::I32, Ty::F64], returns: &[Ty::F64] },
None,
&["jn"],
),
(
// `(f32, i32) -> f32`
FloatTy::F32,
Signature { args: &[Ty::F32, Ty::I32], returns: &[Ty::F32] },
None,
&["scalbnf", "ldexpf"],
),
(
// `(f64, i64) -> f64`
FloatTy::F64,
Signature { args: &[Ty::F64, Ty::I32], returns: &[Ty::F64] },
None,
&["scalbn", "ldexp"],
),
(
// `(f32, &mut f32) -> f32` as `(f32) -> (f32, f32)`
FloatTy::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
Some(Signature { args: &[Ty::F32, Ty::MutF32], returns: &[Ty::F32] }),
&["modff"],
),
(
// `(f64, &mut f64) -> f64` as `(f64) -> (f64, f64)`
FloatTy::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
Some(Signature { args: &[Ty::F64, Ty::MutF64], returns: &[Ty::F64] }),
&["modf"],
),
(
// `(f32, &mut c_int) -> f32` as `(f32) -> (f32, i32)`
FloatTy::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::I32] },
Some(Signature { args: &[Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
&["frexpf", "lgammaf_r"],
),
(
// `(f64, &mut c_int) -> f64` as `(f64) -> (f64, i32)`
FloatTy::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::I32] },
Some(Signature { args: &[Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
&["frexp", "lgamma_r"],
),
(
// `(f32, f32, &mut c_int) -> f32` as `(f32, f32) -> (f32, i32)`
FloatTy::F32,
Signature { args: &[Ty::F32, Ty::F32], returns: &[Ty::F32, Ty::I32] },
Some(Signature { args: &[Ty::F32, Ty::F32, Ty::MutCInt], returns: &[Ty::F32] }),
&["remquof"],
),
(
// `(f64, f64, &mut c_int) -> f64` as `(f64, f64) -> (f64, i32)`
FloatTy::F64,
Signature { args: &[Ty::F64, Ty::F64], returns: &[Ty::F64, Ty::I32] },
Some(Signature { args: &[Ty::F64, Ty::F64, Ty::MutCInt], returns: &[Ty::F64] }),
&["remquo"],
),
(
// `(f32, &mut f32, &mut f32)` as `(f32) -> (f32, f32)`
FloatTy::F32,
Signature { args: &[Ty::F32], returns: &[Ty::F32, Ty::F32] },
Some(Signature { args: &[Ty::F32, Ty::MutF32, Ty::MutF32], returns: &[] }),
&["sincosf"],
),
(
// `(f64, &mut f64, &mut f64)` as `(f64) -> (f64, f64)`
FloatTy::F64,
Signature { args: &[Ty::F64], returns: &[Ty::F64, Ty::F64] },
Some(Signature { args: &[Ty::F64, Ty::MutF64, Ty::MutF64], returns: &[] }),
&["sincos"],
),
];
/// A type used in a function signature.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Ty {
F16,
F32,
F64,
F128,
I32,
CInt,
MutF16,
MutF32,
MutF64,
MutF128,
MutI32,
MutCInt,
}
/// A subset of [`Ty`] representing only floats.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FloatTy {
F16,
F32,
F64,
F128,
}
impl fmt::Display for Ty {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Ty::F16 => "f16",
Ty::F32 => "f32",
Ty::F64 => "f64",
Ty::F128 => "f128",
Ty::I32 => "i32",
Ty::CInt => "::core::ffi::c_int",
Ty::MutF16 => "&mut f16",
Ty::MutF32 => "&mut f32",
Ty::MutF64 => "&mut f64",
Ty::MutF128 => "&mut f128",
Ty::MutI32 => "&mut i32",
Ty::MutCInt => "&mut ::core::ffi::c_int",
};
f.write_str(s)
}
}
impl fmt::Display for FloatTy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
FloatTy::F16 => "f16",
FloatTy::F32 => "f32",
FloatTy::F64 => "f64",
FloatTy::F128 => "f128",
};
f.write_str(s)
}
}
/// Representation of e.g. `(f32, f32) -> f32`
#[derive(Debug, Clone)]
pub struct Signature {
pub args: &'static [Ty],
pub returns: &'static [Ty],
}
/// Combined information about a function implementation.
#[derive(Debug, Clone)]
pub struct MathOpInfo {
pub name: &'static str,
pub float_ty: FloatTy,
/// Function signature for C implementations
pub c_sig: Signature,
/// Function signature for Rust implementations
pub rust_sig: Signature,
}
/// A flat representation of `ALL_FUNCTIONS`.
pub static ALL_OPERATIONS: LazyLock<Vec<MathOpInfo>> = LazyLock::new(|| {
let mut ret = Vec::new();
for (base_fty, rust_sig, c_sig, names) in ALL_OPERATIONS_NESTED {
for name in *names {
let api = MathOpInfo {
name,
float_ty: *base_fty,
rust_sig: rust_sig.clone(),
c_sig: c_sig.clone().unwrap_or_else(|| rust_sig.clone()),
};
ret.push(api);
}
}
ret.sort_by_key(|item| item.name);
ret
});