Rollup merge of #142809 - KMJ-007:ad-type-analysis-flag, r=ZuseZ4

Add PrintTAFn flag for targeted type analysis printing

## Summary
This PR adds a new `PrintTAFn` flag to the `-Z autodiff` option that allows printing type analysis information for a specific function, rather than all functions.

## Changes

### New Flag
- Added `PrintTAFn=<function_name>` option to `-Z autodiff`
- Usage: `-Z autodiff=Enable,PrintTAFn=my_function_name`

### Implementation Details
- **Rust side**: Added `PrintTAFn(String)` variant to `AutoDiff` enum
- **Parser**: Updated `parse_autodiff` to handle `PrintTAFn=<function_name>` syntax with proper error handling
- **FFI**: Added `set_print_type_fun` function to interface with Enzyme's `FunctionToAnalyze` command line option
- **Documentation**: Updated help text and documentation for the new flag

### Files Modified
- `compiler/rustc_session/src/config.rs`: Added `PrintTAFn(String)` variant
- `compiler/rustc_session/src/options.rs`: Updated parser and help text (now shows `PrintTAFn` in the list)
- `compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs`: Added FFI function and static variable
- `compiler/rustc_codegen_llvm/src/back/lto.rs`: Added handling for new flag
- `src/doc/rustc-dev-guide/src/autodiff/flags.md`: Updated documentation
- `src/doc/unstable-book/src/compiler-flags/autodiff.md`: Updated documentation

## Testing
The flag can be tested with:
```bash
rustc +enzyme -Z autodiff=Enable,PrintTAFn=square test.rs
```

This will print type analysis information only for the function named "square" instead of all functions.

## Error Handling
The parser includes proper error handling:
- Missing argument: `PrintTAFn` without `=<function_name>` will show an error
- Unknown options: Invalid autodiff options will be reported

r? ```@ZuseZ4```
This commit is contained in:
Jana Dönszelmann
2025-06-25 22:14:55 +02:00
committed by GitHub
7 changed files with 43 additions and 5 deletions

View File

@@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*;
#[cfg(llvm_enzyme)]
pub(crate) mod Enzyme_AD {
use std::ffi::{CString, c_char};
use libc::c_void;
unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
}
unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
static mut EnzymePrintType: c_void;
static mut EnzymeFunctionToAnalyze: c_void;
static mut EnzymePrint: c_void;
static mut EnzymeStrictAliasing: c_void;
static mut looseTypeAnalysis: c_void;
@@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
}
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
let c_fun_name = CString::new(fun_name).unwrap();
unsafe {
EnzymeSetCLString(
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
c_fun_name.as_ptr() as *const c_char,
);
}
}
pub(crate) fn set_print(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD {
pub(crate) fn set_print_type(print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
unimplemented!()
}
pub(crate) fn set_print(print: bool) {
unimplemented!()
}