Auto merge of #118420 - compiler-errors:async-gen, r=eholk

Introduce support for `async gen` blocks

I'm delighted to demonstrate that `async gen` block are not very difficult to support. They're simply coroutines that yield `Poll<Option<T>>` and return `()`.

**This PR is WIP and in draft mode for now** -- I'm mostly putting it up to show folks that it's possible. This PR needs a lang-team experiment associated with it or possible an RFC, since I don't think it falls under the jurisdiction of the `gen` RFC that was recently authored by oli (https://github.com/rust-lang/rfcs/pull/3513, https://github.com/rust-lang/rust/issues/117078).

### Technical note on the pre-generator-transform yield type:

The reason that the underlying coroutines yield `Poll<Option<T>>` and not `Poll<T>` (which would make more sense, IMO, for the pre-transformed coroutine), is because the `TransformVisitor` that is used to turn coroutines into built-in state machine functions would have to destructure and reconstruct the latter into the former, which requires at least inserting a new basic block (for a `switchInt` terminator, to match on the `Poll` discriminant).

This does mean that the desugaring (at the `rustc_ast_lowering` level) of `async gen` blocks is a bit more involved. However, since we already need to intercept both `.await` and `yield` operators, I don't consider it much of a technical burden.

r? `@ghost`
This commit is contained in:
bors
2023-12-08 19:13:57 +00:00
61 changed files with 1120 additions and 357 deletions

View File

@@ -14,6 +14,7 @@ use rustc_ast::*;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_middle::span_bug;
use rustc_session::errors::report_lit_error;
use rustc_span::source_map::{respan, Spanned};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
@@ -196,22 +197,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
binder,
capture_clause,
constness,
coro_kind,
coroutine_kind,
movability,
fn_decl,
body,
fn_decl_span,
fn_arg_span,
}) => match coro_kind {
Some(
CoroutineKind::Async { closure_id, .. }
| CoroutineKind::Gen { closure_id, .. },
) => self.lower_expr_async_closure(
}) => match coroutine_kind {
Some(coroutine_kind) => self.lower_expr_coroutine_closure(
binder,
*capture_clause,
e.id,
hir_id,
*closure_id,
*coroutine_kind,
fn_decl,
body,
*fn_decl_span,
@@ -325,6 +323,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self
.make_async_gen_expr(
*capture_clause,
e.id,
None,
e.span,
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()),
ExprKind::Err => hir::ExprKind::Err(
self.tcx.sess.span_delayed_bug(e.span, "lowered ExprKind::Err"),
@@ -736,6 +743,87 @@ impl<'hir> LoweringContext<'_, 'hir> {
}))
}
/// Lower a `async gen` construct to a generator that implements `AsyncIterator`.
///
/// This results in:
///
/// ```text
/// static move? |_task_context| -> () {
/// <body>
/// }
/// ```
pub(super) fn make_async_gen_expr(
&mut self,
capture_clause: CaptureBy,
closure_node_id: NodeId,
_yield_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
async_coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
// Resume argument type: `ResumeTy`
let unstable_span = self.mark_span_with_reason(
DesugaringKind::Async,
span,
Some(self.allow_gen_future.clone()),
);
let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span);
let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
};
// The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
let fn_decl = self.arena.alloc(hir::FnDecl {
inputs: arena_vec![self; input_ty],
output,
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
lifetime_elision_allowed: false,
});
// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
let (pat, task_context_hid) = self.pat_ident_binding_mode(
span,
Ident::with_dummy_span(sym::_task_context),
hir::BindingAnnotation::MUT,
);
let param = hir::Param {
hir_id: self.next_id(),
pat,
ty_span: self.lower_span(span),
span: self.lower_span(span),
};
let params = arena_vec![self; param];
let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source));
let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
let res = body(this);
this.task_context = old_ctx;
(params, res)
});
// `static |_task_context| -> <ret_ty> { body }`:
hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
def_id: self.local_def_id(closure_node_id),
binder: hir::ClosureBinder::Default,
capture_clause,
bound_generic_params: &[],
fn_decl,
body,
fn_decl_span: self.lower_span(span),
fn_arg_span: None,
movability: Some(hir::Movability::Static),
constness: hir::Constness::NotConst,
}))
}
/// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to
/// `inner_hir_id` in case the `async_fn_track_caller` feature is enabled.
pub(super) fn maybe_forward_track_caller(
@@ -785,15 +873,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// ```
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
let full_span = expr.span.to(await_kw_span);
match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => {}
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
await_kw_span,
item_span: self.current_item,
}));
}
}
};
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await,
@@ -882,12 +973,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.stmt_expr(span, match_expr)
};
// task_context = yield ();
// Depending on `async` of `async gen`:
// async - task_context = yield ();
// async gen - task_context = yield ASYNC_GEN_PENDING;
let yield_stmt = {
let unit = self.expr_unit(span);
let yielded = if is_async_gen {
self.arena.alloc(self.expr_lang_item_path(span, hir::LangItem::AsyncGenPending))
} else {
self.expr_unit(span)
};
let yield_expr = self.expr(
span,
hir::ExprKind::Yield(unit, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
hir::ExprKind::Yield(yielded, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
);
let yield_expr = self.arena.alloc(yield_expr);
@@ -997,7 +1095,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
Some(movability)
}
Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => {
Some(
hir::CoroutineKind::Gen(_)
| hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_),
) => {
panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
}
None => {
@@ -1024,18 +1126,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
(binder, params)
}
fn lower_expr_async_closure(
fn lower_expr_coroutine_closure(
&mut self,
binder: &ClosureBinder,
capture_clause: CaptureBy,
closure_id: NodeId,
closure_hir_id: hir::HirId,
inner_closure_id: NodeId,
coroutine_kind: CoroutineKind,
decl: &FnDecl,
body: &Expr,
fn_decl_span: Span,
fn_arg_span: Span,
) -> hir::ExprKind<'hir> {
let CoroutineKind::Async { closure_id: inner_closure_id, .. } = coroutine_kind else {
span_bug!(fn_decl_span, "`async gen` and `gen` closures are not supported, yet");
};
if let &ClosureBinder::For { span, .. } = binder {
self.tcx.sess.emit_err(NotSupportedForLifetimeBinderAsyncClosure { span });
}
@@ -1504,8 +1610,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => {}
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Async(_)) => {
return hir::ExprKind::Err(
self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }),
@@ -1521,14 +1628,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
)
.emit();
}
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine)
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine);
false
}
}
};
let expr =
let mut yielded =
opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span));
hir::ExprKind::Yield(expr, hir::YieldSource::Yield)
if is_async_gen {
// yield async_gen_ready($expr);
yielded = self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncGenReady,
std::slice::from_ref(yielded),
);
}
hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
}
/// Desugar `ExprForLoop` from: `[opt_ident]: for <pat> in <head> <body>` into: