Implement async gen blocks

This commit is contained in:
Michael Goulet
2023-11-28 18:18:19 +00:00
parent a0cbc168c9
commit 96bb542a31
32 changed files with 563 additions and 54 deletions

View File

@@ -324,6 +324,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self
.make_async_gen_expr(
*capture_clause,
e.id,
None,
e.span,
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()),
ExprKind::Err => hir::ExprKind::Err(
self.tcx.sess.span_delayed_bug(e.span, "lowered ExprKind::Err"),
@@ -706,6 +715,87 @@ impl<'hir> LoweringContext<'_, 'hir> {
}))
}
/// Lower a `async gen` construct to a generator that implements `AsyncIterator`.
///
/// This results in:
///
/// ```text
/// static move? |_task_context| -> () {
/// <body>
/// }
/// ```
pub(super) fn make_async_gen_expr(
&mut self,
capture_clause: CaptureBy,
closure_node_id: NodeId,
_yield_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
async_coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
// Resume argument type: `ResumeTy`
let unstable_span = self.mark_span_with_reason(
DesugaringKind::Async,
span,
Some(self.allow_gen_future.clone()),
);
let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span);
let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
};
// The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
let fn_decl = self.arena.alloc(hir::FnDecl {
inputs: arena_vec![self; input_ty],
output,
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
lifetime_elision_allowed: false,
});
// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
let (pat, task_context_hid) = self.pat_ident_binding_mode(
span,
Ident::with_dummy_span(sym::_task_context),
hir::BindingAnnotation::MUT,
);
let param = hir::Param {
hir_id: self.next_id(),
pat,
ty_span: self.lower_span(span),
span: self.lower_span(span),
};
let params = arena_vec![self; param];
let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source));
let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
let res = body(this);
this.task_context = old_ctx;
(params, res)
});
// `static |_task_context| -> <ret_ty> { body }`:
hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
def_id: self.local_def_id(closure_node_id),
binder: hir::ClosureBinder::Default,
capture_clause,
bound_generic_params: &[],
fn_decl,
body,
fn_decl_span: self.lower_span(span),
fn_arg_span: None,
movability: Some(hir::Movability::Static),
constness: hir::Constness::NotConst,
}))
}
/// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to
/// `inner_hir_id` in case the `async_fn_track_caller` feature is enabled.
pub(super) fn maybe_forward_track_caller(
@@ -755,15 +845,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// ```
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
let full_span = expr.span.to(await_kw_span);
match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => {}
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
await_kw_span,
item_span: self.current_item,
}));
}
}
};
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await,
@@ -852,12 +945,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.stmt_expr(span, match_expr)
};
// task_context = yield ();
// Depending on `async` of `async gen`:
// async - task_context = yield ();
// async gen - task_context = yield ASYNC_GEN_PENDING;
let yield_stmt = {
let unit = self.expr_unit(span);
let yielded = if is_async_gen {
self.arena.alloc(self.expr_lang_item_path(span, hir::LangItem::AsyncGenPending))
} else {
self.expr_unit(span)
};
let yield_expr = self.expr(
span,
hir::ExprKind::Yield(unit, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
hir::ExprKind::Yield(yielded, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
);
let yield_expr = self.arena.alloc(yield_expr);
@@ -967,7 +1067,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
Some(movability)
}
Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => {
Some(
hir::CoroutineKind::Gen(_)
| hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_),
) => {
panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
}
None => {
@@ -1474,8 +1578,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => {}
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Async(_)) => {
return hir::ExprKind::Err(
self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }),
@@ -1491,14 +1596,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
)
.emit();
}
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine)
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine);
false
}
}
};
let expr =
let mut yielded =
opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span));
hir::ExprKind::Yield(expr, hir::YieldSource::Yield)
if is_async_gen {
// yield async_gen_ready($expr);
yielded = self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncGenReady,
std::slice::from_ref(yielded),
);
}
hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
}
/// Desugar `ExprForLoop` from: `[opt_ident]: for <pat> in <head> <body>` into: