Teach typeck/borrowck/solvers how to deal with async closures

This commit is contained in:
Michael Goulet
2024-01-24 22:27:25 +00:00
parent c567eddec2
commit a82bae2172
35 changed files with 1221 additions and 66 deletions

View File

@@ -1833,10 +1833,28 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
lang_items.fn_trait(),
lang_items.fn_mut_trait(),
lang_items.fn_once_trait(),
lang_items.async_fn_trait(),
lang_items.async_fn_mut_trait(),
lang_items.async_fn_once_trait(),
].contains(&Some(trait_ref.def_id))
{
true
}else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) {
} else if lang_items.async_fn_kind_helper() == Some(trait_ref.def_id) {
// FIXME(async_closures): Validity constraints here could be cleaned up.
if obligation.predicate.args.type_at(0).is_ty_var()
|| obligation.predicate.args.type_at(4).is_ty_var()
|| obligation.predicate.args.type_at(5).is_ty_var()
{
candidate_set.mark_ambiguous();
true
} else if obligation.predicate.args.type_at(0).to_opt_closure_kind().is_some()
&& obligation.predicate.args.type_at(1).to_opt_closure_kind().is_some()
{
true
} else {
false
}
} else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) {
match self_ty.kind() {
ty::Bool
| ty::Char
@@ -2061,6 +2079,10 @@ fn confirm_select_candidate<'cx, 'tcx>(
} else {
confirm_fn_pointer_candidate(selcx, obligation, data)
}
} else if selcx.tcx().async_fn_trait_kind_from_def_id(trait_def_id).is_some() {
confirm_async_closure_candidate(selcx, obligation, data)
} else if lang_items.async_fn_kind_helper() == Some(trait_def_id) {
confirm_async_fn_kind_helper_candidate(selcx, obligation, data)
} else {
confirm_builtin_candidate(selcx, obligation, data)
}
@@ -2421,6 +2443,164 @@ fn confirm_callable_candidate<'cx, 'tcx>(
confirm_param_env_candidate(selcx, obligation, predicate, true)
}
fn confirm_async_closure_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
mut nested: Vec<PredicateObligation<'tcx>>,
) -> Progress<'tcx> {
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else {
unreachable!(
"expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}"
)
};
let args = args.as_coroutine_closure();
let kind_ty = args.kind_ty();
let tcx = selcx.tcx();
let goal_kind =
tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap();
let helper_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
nested.push(obligation.with(
tcx,
ty::TraitRef::new(
tcx,
helper_trait_def_id,
[kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
),
));
let env_region = match goal_kind {
ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2),
ty::ClosureKind::FnOnce => tcx.lifetimes.re_static,
};
// FIXME(async_closures): Make this into a lang item.
let upvars_projection_def_id =
tcx.associated_items(helper_trait_def_id).in_definition_order().next().unwrap().def_id;
// FIXME(async_closures): Confirmation is kind of a mess here. Ideally,
// we'd short-circuit when we know that the goal_kind >= closure_kind, and not
// register a nested predicate or create a new projection ty here. But I'm too
// lazy to make this more efficient atm, and we can always tweak it later,
// since all this does is make the solver do more work.
//
// The code duplication due to the different length args is kind of weird, too.
let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| {
let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) {
sym::CallOnceFuture => {
let tupled_upvars_ty = Ty::new_projection(
tcx,
upvars_projection_def_id,
[
ty::GenericArg::from(kind_ty),
Ty::from_closure_kind(tcx, goal_kind).into(),
env_region.into(),
sig.tupled_inputs_ty.into(),
args.tupled_upvars_ty().into(),
args.coroutine_captures_by_ref_ty().into(),
],
);
let coroutine_ty = sig.to_coroutine(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
);
(
ty::AliasTy::new(
tcx,
obligation.predicate.def_id,
[self_ty, sig.tupled_inputs_ty],
),
coroutine_ty.into(),
)
}
sym::CallMutFuture | sym::CallFuture => {
let tupled_upvars_ty = Ty::new_projection(
tcx,
upvars_projection_def_id,
[
ty::GenericArg::from(kind_ty),
Ty::from_closure_kind(tcx, goal_kind).into(),
env_region.into(),
sig.tupled_inputs_ty.into(),
args.tupled_upvars_ty().into(),
args.coroutine_captures_by_ref_ty().into(),
],
);
let coroutine_ty = sig.to_coroutine(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
);
(
ty::AliasTy::new(
tcx,
obligation.predicate.def_id,
[
ty::GenericArg::from(self_ty),
sig.tupled_inputs_ty.into(),
env_region.into(),
],
),
coroutine_ty.into(),
)
}
sym::Output => (
ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]),
sig.return_ty.into(),
),
name => bug!("no such associated type: {name}"),
};
ty::ProjectionPredicate { projection_ty, term }
});
confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true)
.with_addl_obligations(nested)
}
fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
nested: Vec<PredicateObligation<'tcx>>,
) -> Progress<'tcx> {
let [
// We already checked that the goal_kind >= closure_kind
_closure_kind_ty,
goal_kind_ty,
borrow_region,
tupled_inputs_ty,
tupled_upvars_ty,
coroutine_captures_by_ref_ty,
] = **obligation.predicate.args
else {
bug!();
};
let predicate = ty::ProjectionPredicate {
projection_ty: ty::AliasTy::new(
selcx.tcx(),
obligation.predicate.def_id,
obligation.predicate.args,
),
term: ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
selcx.tcx(),
goal_kind_ty.expect_ty().to_opt_closure_kind().unwrap(),
tupled_inputs_ty.expect_ty(),
tupled_upvars_ty.expect_ty(),
coroutine_captures_by_ref_ty.expect_ty(),
borrow_region.expect_region(),
)
.into(),
};
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
.with_addl_obligations(nested)
}
fn confirm_param_env_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,

View File

@@ -117,9 +117,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
self.assemble_iterator_candidates(obligation, &mut candidates);
} else if lang_items.async_iterator_trait() == Some(def_id) {
self.assemble_async_iterator_candidates(obligation, &mut candidates);
} else if lang_items.async_fn_kind_helper() == Some(def_id) {
self.assemble_async_fn_kind_helper_candidates(obligation, &mut candidates);
}
self.assemble_closure_candidates(obligation, &mut candidates);
self.assemble_async_closure_candidates(obligation, &mut candidates);
self.assemble_fn_pointer_candidates(obligation, &mut candidates);
self.assemble_candidates_from_impls(obligation, &mut candidates);
self.assemble_candidates_from_object_ty(obligation, &mut candidates);
@@ -335,6 +338,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
}
fn assemble_async_closure_candidates(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidates: &mut SelectionCandidateSet<'tcx>,
) {
let Some(goal_kind) =
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id())
else {
return;
};
match *obligation.self_ty().skip_binder().kind() {
ty::CoroutineClosure(_, args) => {
if let Some(closure_kind) =
args.as_coroutine_closure().kind_ty().to_opt_closure_kind()
&& !closure_kind.extends(goal_kind)
{
return;
}
candidates.vec.push(AsyncClosureCandidate);
}
ty::Infer(ty::TyVar(_)) => {
candidates.ambiguous = true;
}
_ => {}
}
}
fn assemble_async_fn_kind_helper_candidates(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidates: &mut SelectionCandidateSet<'tcx>,
) {
if let Some(closure_kind) = obligation.self_ty().skip_binder().to_opt_closure_kind()
&& let Some(goal_kind) =
obligation.predicate.skip_binder().trait_ref.args.type_at(1).to_opt_closure_kind()
{
if closure_kind.extends(goal_kind) {
candidates.vec.push(AsyncFnKindHelperCandidate);
}
}
}
/// Implements one of the `Fn()` family for a fn pointer.
fn assemble_fn_pointer_candidates(
&mut self,

View File

@@ -83,6 +83,13 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure)
}
AsyncClosureCandidate => {
let vtable_closure = self.confirm_async_closure_candidate(obligation)?;
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure)
}
AsyncFnKindHelperCandidate => ImplSource::Builtin(BuiltinImplSource::Misc, vec![]),
CoroutineCandidate => {
let vtable_coroutine = self.confirm_coroutine_candidate(obligation)?;
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_coroutine)
@@ -869,6 +876,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
Ok(nested)
}
#[instrument(skip(self), level = "debug")]
fn confirm_async_closure_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on closure types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else {
bug!("async closure candidate for non-coroutine-closure {:?}", obligation);
};
let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
self.tcx(),
obligation.predicate.def_id(),
[self_ty, sig.tupled_inputs_ty],
)
});
let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
let goal_kind =
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
nested.push(obligation.with(
self.tcx(),
ty::TraitRef::from_lang_item(
self.tcx(),
LangItem::AsyncFnKindHelper,
obligation.cause.span,
[
args.as_coroutine_closure().kind_ty(),
Ty::from_closure_kind(self.tcx(), goal_kind),
],
),
));
debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
Ok(nested)
}
/// In the case of closure types and fn pointers,
/// we currently treat the input type parameters on the trait as
/// outputs. This means that when we have a match we have only

View File

@@ -1864,6 +1864,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
ImplCandidate(..)
| AutoImplCandidate
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
@@ -1894,6 +1896,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
ImplCandidate(_)
| AutoImplCandidate
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
@@ -1930,6 +1934,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
ImplCandidate(..)
| AutoImplCandidate
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
@@ -1946,6 +1952,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
ImplCandidate(..)
| AutoImplCandidate
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
@@ -2054,6 +2062,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
(
ImplCandidate(_)
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
@@ -2066,6 +2076,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| TraitAliasCandidate,
ImplCandidate(_)
| ClosureCandidate { .. }
| AsyncClosureCandidate
| AsyncFnKindHelperCandidate
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate