add autodiff batching backend

This commit is contained in:
Manuel Drehwald
2025-04-04 14:24:23 -04:00
parent e0c8ead880
commit b7c63a973f
6 changed files with 196 additions and 44 deletions

View File

@@ -8,6 +8,7 @@ use std::str;
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::errors as ssa_errors;
use rustc_codegen_ssa::traits::*;
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@@ -38,7 +39,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
use crate::llvm::Metadata;
use crate::type_::Type;
use crate::value::Value;
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
/// `TyCtxt` (and related cache datastructures) can't be move between threads.
/// However, there are various cx related functions which we want to be available to the builder and
@@ -643,7 +644,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
llvm::set_section(g, c"llvm.metadata");
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
assert_eq!(self.type_kind(ty), TypeKind::Function);
unsafe { llvm::LLVMGetReturnType(ty) }
}
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
unsafe { llvm::LLVMGlobalGetValueType(val) }
}
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
common::val_ty(v)
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn new(
llmod: &'ll llvm::Module,
@@ -660,6 +672,13 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
}
// FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
// onto a trait that is also implemented for GenericCx.
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }