fix LooseTypes flag and PrintMod behaviour, add debug helper
This commit is contained in:
@@ -584,12 +584,10 @@ fn thin_lto(
|
||||
}
|
||||
}
|
||||
|
||||
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
|
||||
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
|
||||
for &val in ad {
|
||||
// We intentionally don't use a wildcard, to not forget handling anything new.
|
||||
match val {
|
||||
config::AutoDiff::PrintModBefore => {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
config::AutoDiff::PrintPerf => {
|
||||
llvm::set_print_perf(true);
|
||||
}
|
||||
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
|
||||
llvm::set_inline(true);
|
||||
}
|
||||
config::AutoDiff::LooseTypes => {
|
||||
llvm::set_loose_types(false);
|
||||
llvm::set_loose_types(true);
|
||||
}
|
||||
config::AutoDiff::PrintSteps => {
|
||||
llvm::set_print(true);
|
||||
}
|
||||
// We handle this below
|
||||
// We handle this in the PassWrapper.cpp
|
||||
config::AutoDiff::PrintPasses => {}
|
||||
// We handle this in the PassWrapper.cpp
|
||||
config::AutoDiff::PrintModBefore => {}
|
||||
// We handle this in the PassWrapper.cpp
|
||||
config::AutoDiff::PrintModAfter => {}
|
||||
// We handle this below
|
||||
// We handle this in the PassWrapper.cpp
|
||||
config::AutoDiff::PrintModFinal => {}
|
||||
// This is required and already checked
|
||||
config::AutoDiff::Enable => {}
|
||||
// We handle this below
|
||||
config::AutoDiff::NoPostopt => {}
|
||||
}
|
||||
}
|
||||
// This helps with handling enums for now.
|
||||
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
|
||||
// We then run the llvm_optimize function a second time, to optimize the code which we generated
|
||||
// in the enzyme differentiation pass.
|
||||
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
|
||||
let stage =
|
||||
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
|
||||
let stage = if thin {
|
||||
write::AutodiffStage::PreAD
|
||||
} else {
|
||||
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
|
||||
};
|
||||
|
||||
if enable_ad {
|
||||
enable_autodiff_settings(&config.autodiff, module);
|
||||
enable_autodiff_settings(&config.autodiff);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
||||
}
|
||||
|
||||
if cfg!(llvm_enzyme) && enable_ad {
|
||||
// This is the post-autodiff IR, mainly used for testing and educational purposes.
|
||||
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
|
||||
if cfg!(llvm_enzyme) && enable_ad && !thin {
|
||||
let opt_stage = llvm::OptStage::FatLTO;
|
||||
let stage = write::AutodiffStage::PostAD;
|
||||
unsafe {
|
||||
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
||||
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
|
||||
unsafe {
|
||||
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
||||
}
|
||||
}
|
||||
|
||||
// This is the final IR, so people should be able to inspect the optimized autodiff output,
|
||||
|
||||
@@ -565,6 +565,9 @@ pub(crate) unsafe fn llvm_optimize(
|
||||
|
||||
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
|
||||
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
|
||||
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
|
||||
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
|
||||
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
|
||||
let unroll_loops;
|
||||
let vectorize_slp;
|
||||
let vectorize_loop;
|
||||
@@ -663,6 +666,9 @@ pub(crate) unsafe fn llvm_optimize(
|
||||
config.no_builtins,
|
||||
config.emit_lifetime_markers,
|
||||
run_enzyme,
|
||||
print_before_enzyme,
|
||||
print_after_enzyme,
|
||||
print_passes,
|
||||
sanitizer_options.as_ref(),
|
||||
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
|
||||
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
|
||||
|
||||
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
|
||||
DisableSimplifyLibCalls: bool,
|
||||
EmitLifetimeMarkers: bool,
|
||||
RunEnzyme: bool,
|
||||
PrintBeforeEnzyme: bool,
|
||||
PrintAfterEnzyme: bool,
|
||||
PrintPasses: bool,
|
||||
SanitizerOptions: Option<&SanitizerOptions>,
|
||||
PGOGenPath: *const c_char,
|
||||
PGOUsePath: *const c_char,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IRPrinter/IRPrintingPasses.h"
|
||||
#include "llvm/LTO/LTO.h"
|
||||
#include "llvm/MC/MCSubtargetInfo.h"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
|
||||
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
|
||||
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
|
||||
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
|
||||
bool EmitLifetimeMarkers, bool RunEnzyme,
|
||||
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
|
||||
bool PrintAfterEnzyme, bool PrintPasses,
|
||||
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
|
||||
const char *PGOUsePath, bool InstrumentCoverage,
|
||||
const char *InstrProfileOutput, const char *PGOSampleUsePath,
|
||||
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
|
||||
// now load "-enzyme" pass:
|
||||
#ifdef ENZYME
|
||||
if (RunEnzyme) {
|
||||
registerEnzymeAndPassPipeline(PB, true);
|
||||
|
||||
if (PrintBeforeEnzyme) {
|
||||
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
|
||||
std::string Banner = "Module before EnzymeNewPM";
|
||||
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
|
||||
}
|
||||
|
||||
registerEnzymeAndPassPipeline(PB, false);
|
||||
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
|
||||
std::string ErrMsg = toString(std::move(Err));
|
||||
LLVMRustSetLastError(ErrMsg.c_str());
|
||||
return LLVMRustResult::Failure;
|
||||
}
|
||||
|
||||
if (PrintAfterEnzyme) {
|
||||
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
|
||||
std::string Banner = "Module after EnzymeNewPM";
|
||||
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (PrintPasses) {
|
||||
// Print all passes from the PM:
|
||||
std::string Pipeline;
|
||||
raw_string_ostream SOS(Pipeline);
|
||||
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
|
||||
auto PassName = PIC.getPassNameForClassName(ClassName);
|
||||
return PassName.empty() ? ClassName : PassName;
|
||||
});
|
||||
outs() << Pipeline;
|
||||
outs() << "\n";
|
||||
}
|
||||
|
||||
// Upgrade all calls to old intrinsics first.
|
||||
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
|
||||
|
||||
@@ -244,6 +244,10 @@ pub enum AutoDiff {
|
||||
/// Print the module after running autodiff and optimizations.
|
||||
PrintModFinal,
|
||||
|
||||
/// Print all passes scheduled by LLVM
|
||||
PrintPasses,
|
||||
/// Disable extra opt run after running autodiff
|
||||
NoPostopt,
|
||||
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
|
||||
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
|
||||
LooseTypes,
|
||||
|
||||
@@ -711,7 +711,7 @@ mod desc {
|
||||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||
pub(crate) const parse_list_with_polarity: &str =
|
||||
"a comma-separated list of strings, with elements beginning with + or -";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||
pub(crate) const parse_number: &str = "a number";
|
||||
@@ -1360,6 +1360,8 @@ pub mod parse {
|
||||
"PrintModBefore" => AutoDiff::PrintModBefore,
|
||||
"PrintModAfter" => AutoDiff::PrintModAfter,
|
||||
"PrintModFinal" => AutoDiff::PrintModFinal,
|
||||
"NoPostopt" => AutoDiff::NoPostopt,
|
||||
"PrintPasses" => AutoDiff::PrintPasses,
|
||||
"LooseTypes" => AutoDiff::LooseTypes,
|
||||
"Inline" => AutoDiff::Inline,
|
||||
_ => {
|
||||
@@ -2095,6 +2097,8 @@ options! {
|
||||
`=PrintModBefore`
|
||||
`=PrintModAfter`
|
||||
`=PrintModFinal`
|
||||
`=PrintPasses`
|
||||
`=NoPostopt`
|
||||
`=LooseTypes`
|
||||
`=Inline`
|
||||
Multiple options can be combined with commas."),
|
||||
|
||||
Reference in New Issue
Block a user