fix LooseTypes flag and PrintMod behaviour, add debug helper

This commit is contained in:
Manuel Drehwald
2025-04-12 00:50:41 -04:00
parent e643f59f6d
commit 75f86e6e2e
6 changed files with 68 additions and 21 deletions

View File

@@ -584,12 +584,10 @@ fn thin_lto(
}
}
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
for &val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintModBefore => {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
}
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
llvm::set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(false);
llvm::set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModBefore => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModAfter => {}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModFinal => {}
// This is required and already checked
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
}
}
// This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let stage =
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
let stage = if thin {
write::AutodiffStage::PreAD
} else {
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
};
if enable_ad {
enable_autodiff_settings(&config.autodiff, module);
enable_autodiff_settings(&config.autodiff);
}
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}
if cfg!(llvm_enzyme) && enable_ad {
// This is the post-autodiff IR, mainly used for testing and educational purposes.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
if cfg!(llvm_enzyme) && enable_ad && !thin {
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}
}
// This is the final IR, so people should be able to inspect the optimized autodiff output,

View File

@@ -565,6 +565,9 @@ pub(crate) unsafe fn llvm_optimize(
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
let unroll_loops;
let vectorize_slp;
let vectorize_loop;
@@ -663,6 +666,9 @@ pub(crate) unsafe fn llvm_optimize(
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
print_before_enzyme,
print_after_enzyme,
print_passes,
sanitizer_options.as_ref(),
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),

View File

@@ -2454,6 +2454,9 @@ unsafe extern "C" {
DisableSimplifyLibCalls: bool,
EmitLifetimeMarkers: bool,
RunEnzyme: bool,
PrintBeforeEnzyme: bool,
PrintAfterEnzyme: bool,
PrintPasses: bool,
SanitizerOptions: Option<&SanitizerOptions>,
PGOGenPath: *const c_char,
PGOUsePath: *const c_char,

View File

@@ -14,6 +14,7 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRPrinter/IRPrintingPasses.h"
#include "llvm/LTO/LTO.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
bool EmitLifetimeMarkers, bool RunEnzyme,
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
bool PrintAfterEnzyme, bool PrintPasses,
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
const char *PGOUsePath, bool InstrumentCoverage,
const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
// now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) {
registerEnzymeAndPassPipeline(PB, true);
if (PrintBeforeEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
std::string Banner = "Module before EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
registerEnzymeAndPassPipeline(PB, false);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}
if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
}
#endif
if (PrintPasses) {
// Print all passes from the PM:
std::string Pipeline;
raw_string_ostream SOS(Pipeline);
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
auto PassName = PIC.getPassNameForClassName(ClassName);
return PassName.empty() ? ClassName : PassName;
});
outs() << Pipeline;
outs() << "\n";
}
// Upgrade all calls to old intrinsics first.
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

View File

@@ -244,6 +244,10 @@ pub enum AutoDiff {
/// Print the module after running autodiff and optimizations.
PrintModFinal,
/// Print all passes scheduled by LLVM
PrintPasses,
/// Disable extra opt run after running autodiff
NoPostopt,
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
LooseTypes,

View File

@@ -711,7 +711,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
"PrintModBefore" => AutoDiff::PrintModBefore,
"PrintModAfter" => AutoDiff::PrintModAfter,
"PrintModFinal" => AutoDiff::PrintModFinal,
"NoPostopt" => AutoDiff::NoPostopt,
"PrintPasses" => AutoDiff::PrintPasses,
"LooseTypes" => AutoDiff::LooseTypes,
"Inline" => AutoDiff::Inline,
_ => {
@@ -2095,6 +2097,8 @@ options! {
`=PrintModBefore`
`=PrintModAfter`
`=PrintModFinal`
`=PrintPasses`,
`=NoPostopt`
`=LooseTypes`
`=Inline`
Multiple options can be combined with commas."),