upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
This commit is contained in:
@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
|
||||
use std::{fs, io, mem, str, thread};
|
||||
|
||||
use rustc_ast::attr;
|
||||
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
|
||||
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
|
||||
use rustc_data_structures::jobserver::{self, Acquired};
|
||||
use rustc_data_structures::memmap::Mmap;
|
||||
@@ -40,7 +41,7 @@ use tracing::debug;
|
||||
use super::link::{self, ensure_removed};
|
||||
use super::lto::{self, SerializedModule};
|
||||
use super::symbol_export::symbol_name_for_instance_in_crate;
|
||||
use crate::errors::ErrorCreatingRemarkDir;
|
||||
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
|
||||
use crate::traits::*;
|
||||
use crate::{
|
||||
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
|
||||
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
|
||||
pub merge_functions: bool,
|
||||
pub emit_lifetime_markers: bool,
|
||||
pub llvm_plugins: Vec<String>,
|
||||
pub autodiff: Vec<config::AutoDiff>,
|
||||
}
|
||||
|
||||
impl ModuleConfig {
|
||||
@@ -266,6 +268,7 @@ impl ModuleConfig {
|
||||
|
||||
emit_lifetime_markers: sess.emit_lifetime_markers(),
|
||||
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
|
||||
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
|
||||
|
||||
fn generate_lto_work<B: ExtraBackendMethods>(
|
||||
cgcx: &CodegenContext<B>,
|
||||
autodiff: Vec<AutoDiffItem>,
|
||||
needs_fat_lto: Vec<FatLtoInput<B>>,
|
||||
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
|
||||
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
|
||||
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
|
||||
|
||||
if !needs_fat_lto.is_empty() {
|
||||
assert!(needs_thin_lto.is_empty());
|
||||
let module =
|
||||
let mut module =
|
||||
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
|
||||
if cgcx.lto == Lto::Fat {
|
||||
let config = cgcx.config(ModuleKind::Regular);
|
||||
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
|
||||
}
|
||||
// We are adding a single work item, so the cost doesn't matter.
|
||||
vec![(WorkItem::LTO(module), 0)]
|
||||
} else {
|
||||
if !autodiff.is_empty() {
|
||||
let dcx = cgcx.create_dcx();
|
||||
dcx.handle().emit_fatal(AutodiffWithoutLto {});
|
||||
}
|
||||
assert!(needs_fat_lto.is_empty());
|
||||
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
|
||||
.unwrap_or_else(|e| e.raise());
|
||||
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
|
||||
/// Sent from a backend worker thread.
|
||||
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
|
||||
|
||||
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
|
||||
AddAutoDiffItems(Vec<AutoDiffItem>),
|
||||
|
||||
/// The frontend has finished generating something (backend IR or a
|
||||
/// post-LTO artifact) for a codegen unit, and it should be passed to the
|
||||
/// backend. Sent from the main thread.
|
||||
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||
|
||||
// This is where we collect codegen units that have gone all the way
|
||||
// through codegen and LLVM.
|
||||
let mut autodiff_items = Vec::new();
|
||||
let mut compiled_modules = vec![];
|
||||
let mut compiled_allocator_module = None;
|
||||
let mut needs_link = Vec::new();
|
||||
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||
let needs_thin_lto = mem::take(&mut needs_thin_lto);
|
||||
let import_only_modules = mem::take(&mut lto_import_only_modules);
|
||||
|
||||
for (work, cost) in
|
||||
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
|
||||
{
|
||||
for (work, cost) in generate_lto_work(
|
||||
&cgcx,
|
||||
autodiff_items.clone(),
|
||||
needs_fat_lto,
|
||||
needs_thin_lto,
|
||||
import_only_modules,
|
||||
) {
|
||||
let insertion_index = work_items
|
||||
.binary_search_by_key(&cost, |&(_, cost)| cost)
|
||||
.unwrap_or_else(|e| e);
|
||||
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||
main_thread_state = MainThreadState::Idle;
|
||||
}
|
||||
|
||||
Message::AddAutoDiffItems(mut items) => {
|
||||
autodiff_items.append(&mut items);
|
||||
}
|
||||
|
||||
Message::CodegenComplete => {
|
||||
if codegen_state != Aborted {
|
||||
codegen_state = Completed;
|
||||
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
|
||||
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
|
||||
}
|
||||
|
||||
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
|
||||
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
|
||||
}
|
||||
|
||||
pub(crate) fn check_for_errors(&self, sess: &Session) {
|
||||
self.shared_emitter_main.check(sess, false);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user