Renumber locals after state transform.

This commit is contained in:
Camille Gillot
2025-10-09 14:12:26 +00:00
parent 1ad657ed5f
commit 6d800ae35b
3 changed files with 75 additions and 65 deletions

View File

@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{
@@ -110,6 +110,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
if *local == self.from {
*local = self.to;
} else if *local == self.to {
*local = self.from;
}
}
@@ -206,8 +208,8 @@ struct TransformVisitor<'tcx> {
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
always_live_locals: DenseBitSet<Local>,
// The original RETURN_PLACE local
old_ret_local: Local,
// New local we just create to hold the `CoroutineState` value.
new_ret_local: Local,
old_yield_ty: Ty<'tcx>,
@@ -344,9 +346,10 @@ impl<'tcx> TransformVisitor<'tcx> {
}
};
// Assign to `new_ret_local`, which will be replaced by `RETURN_PLACE` later.
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
));
}
@@ -388,6 +391,20 @@ impl<'tcx> TransformVisitor<'tcx> {
);
(assign, temp)
}
/// Swaps all references of `old_local` and `new_local`.
#[tracing::instrument(level = "trace", skip(self, body))]
fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
body.local_decls.swap(old_local, new_local);
let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
visitor.visit_body(body);
for suspension in &mut self.suspension_points {
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
let location = Location { block: START_BLOCK, statement_index: 0 };
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
}
}
}
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
@@ -419,6 +436,16 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
self.super_statement(stmt, location);
}
#[tracing::instrument(level = "trace", skip(self, term), ret)]
fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
if let TerminatorKind::Return = term.kind {
// `visit_basic_block_data` introduces `Return` terminators which read `RETURN_PLACE`.
// But this `RETURN_PLACE` is already remapped, so we should not touch it again.
return;
}
self.super_terminator(term, location);
}
#[tracing::instrument(level = "trace", skip(self, data), ret)]
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
match data.terminator().kind {
@@ -426,7 +453,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
let source_info = data.terminator().source_info;
// We must assign the value first in case it gets declared dead below
self.make_state(
Operand::Move(Place::from(self.old_ret_local)),
Operand::Move(Place::return_place()),
source_info,
true,
&mut data.statements,
@@ -521,27 +548,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
.visit_body(body);
}
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
///
/// `local` will be changed to a new local decl with type `ty`.
///
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
/// valid value to it before its first use.
fn replace_local<'tcx>(
local: Local,
ty: Ty<'tcx>,
body: &mut Body<'tcx>,
tcx: TyCtxt<'tcx>,
) -> Local {
let new_decl = LocalDecl::new(ty, body.span);
let new_local = body.local_decls.push(new_decl);
body.local_decls.swap(local, new_local);
RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
new_local
}
/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
@@ -1511,10 +1517,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
}
};
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
// We need to insert clean drop for unresumed state and perform drop elaboration
// (finally in open_drop_for_tuple) before async drop expansion.
// Async drops, produced by this drop elaboration, will be expanded,
@@ -1561,6 +1563,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
let can_return = can_return(tcx, body, body.typing_env(tcx));
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
tracing::trace!(?new_ret_local);
// Run the transformation which converts Places from Local to coroutine struct
// accesses for locals in `remap`.
// It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
@@ -1573,13 +1580,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
storage_liveness,
always_live_locals,
suspension_points: Vec::new(),
old_ret_local,
discr_ty,
new_ret_local,
old_ret_ty,
old_yield_ty,
};
transform.visit_body(body);
// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
transform.replace_local(RETURN_PLACE, new_ret_local, body);
// MIR parameters are not explicitly assigned-to when entering the MIR body.
// If we want to save their values inside the coroutine state, we need to do so explicitly.
let source_info = SourceInfo::outermost(body.span);