move second opt run to lto phase and cleanup code

This commit is contained in:
Manuel Drehwald
2025-02-10 01:35:22 -05:00
parent 21d096184e
commit 1221cff551
7 changed files with 75 additions and 54 deletions

View File

@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_errors::FatalError;
use rustc_session::config::Lto;
use tracing::{debug, trace};
use crate::back::write::{llvm_err, llvm_optimize};
use crate::back::write::llvm_err;
use crate::builder::SBuilder;
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
@@ -153,7 +152,7 @@ fn generate_enzyme_call<'ll>(
_ => {}
}
trace!("matching autodiff arguments");
debug!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -222,7 +221,10 @@ fn generate_enzyme_call<'ll>(
// A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...).
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
if matches!(
diff_activity,
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
) {
assert!(
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
);
@@ -282,7 +284,7 @@ pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
config: &ModuleConfig,
_config: &ModuleConfig,
) -> Result<(), FatalError> {
for item in &diff_items {
trace!("{}", item);
@@ -317,29 +319,6 @@ pub(crate) fn differentiate<'ll>(
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto {
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO,
};
// This is our second opt call, so now we run all opts,
// to make sure we get the best performance.
let skip_size_increasing_opts = false;
trace!("running Module Optimization after differentiation");
unsafe {
llvm_optimize(
cgcx,
diag_handler.handle(),
module,
config,
opt_level,
opt_stage,
skip_size_increasing_opts,
)?
};
}
trace!("done with differentiate()");
Ok(())