Rollup merge of #144764 - scottmcm:tweak-impossible-discriminant-assume, r=WaffleLapkin

[codegen] assume the tag, not the relative discriminant

Address the issue mentioned in <https://github.com/llvm/llvm-project/issues/134024#issuecomment-3131782555> by changing discriminant calculation to `assume` on the originally-loaded `tag`, rather than on `cast(tag)-OFFSET`.

The previous approach does make the *purpose* of the assume clearer, IMHO, since you see `assume(x != 4); if p { x } else { 4 }`. However, assuming on the originally-loaded tag instead means that the `add`s optimize away in LLVM21, which is more important. The new form is still easy to think of as metadata on the load saying specifically which value is impossible.

Demo of the LLVM20 vs LLVM21 difference: <https://llvm.godbolt.org/z/n54x5Mq1T>

r? ``@nikic``
This commit is contained in:
Stuart Cook
2025-08-08 12:52:50 +10:00
committed by GitHub
3 changed files with 86 additions and 58 deletions

View File

@@ -498,6 +498,35 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
//
// This was previously assumed on `tagged_discr` below, where the
// impossible value is more obvious, but that caused an intermediate
// value to become multi-use and thus not optimize. Instead, this
// assumes on the original input, which is always multi-use. See
// <https://github.com/llvm/llvm-project/issues/134024#issuecomment-3131782555>
//
// FIXME: If we ever get range assume operand bundles in LLVM (so we
// don't need the `icmp`s in the instruction stream any more), it
// might be worth moving this back to being on the switch argument
// where it's more obviously applicable.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let impossible = niche_start
.wrapping_add(u128::from(untagged_variant.as_u32()))
.wrapping_sub(u128::from(niche_variants.start().as_u32()));
let impossible = bx.cx().const_uint_big(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, tag, impossible);
bx.assume(ne);
}
// With multiple niched variants we'll have to actually compute
// the variant index from the stored tag.
//
@@ -588,20 +617,6 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
let untagged_variant_const =
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
bx.assume(ne);
}
let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);
// In principle we could insert assumes on the possible range of `discr`, but

View File

@@ -91,18 +91,23 @@ pub enum Mid<T> {
pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
// CHECK-LABEL: @mid_bool_eq_discr(
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
// LLVM20: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
// LLVM20: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
// LLVM21: %[[A_MOD_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %a, i8 3
// LLVM21: %[[B_MOD_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %b, i8 3
// LLVM20: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
// LLVM21: %[[R:.+]] = icmp eq i8 %[[A_MOD_DISCR]], %[[B_MOD_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}
@@ -111,19 +116,23 @@ pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
pub fn mid_ord_eq_discr(a: Mid<Ordering>, b: Mid<Ordering>) -> bool {
// CHECK-LABEL: @mid_ord_eq_discr(
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
// LLVM20: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
// LLVM20: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
// LLVM21: %[[A_MOD_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %a, i8 3
// LLVM21: %[[B_MOD_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %b, i8 3
// LLVM20: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
// LLVM21: %[[R:.+]] = icmp eq i8 %[[A_MOD_DISCR]], %[[B_MOD_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}
@@ -140,16 +149,16 @@ pub fn mid_nz32_eq_discr(a: Mid<NonZero<u32>>, b: Mid<NonZero<u32>>) -> bool {
pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
// CHECK-LABEL: @mid_ac_eq_discr(
// LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
// LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
// LLVM21: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %a, i8 -127
@@ -166,21 +175,25 @@ pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
pub fn mid_giant_eq_discr(a: Mid<Giant>, b: Mid<Giant>) -> bool {
// CHECK-LABEL: @mid_giant_eq_discr(
// CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i128 %a, 6
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1
// CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64
// LLVM20: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1
// CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i128 %b, 6
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1
// CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64
// LLVM20: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1
// CHECK: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
// LLVM21: %[[A_MODIFIED_TAG:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_TRUNC]], i64 6
// LLVM21: %[[B_MODIFIED_TAG:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_TRUNC]], i64 6
// LLVM21: %[[R:.+]] = icmp eq i64 %[[A_MODIFIED_TAG]], %[[B_MODIFIED_TAG]]
// LLVM20: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

View File

@@ -138,18 +138,18 @@ pub fn match3(e: Option<&u8>) -> i16 {
#[derive(PartialEq)]
pub enum MiddleNiche {
A,
B,
C(bool),
D,
E,
A, // tag 2
B, // tag 3
C(bool), // untagged
D, // tag 5
E, // tag 6
}
// CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 -?[0-9]+, -?[0-9]+\))?}} i8 @match4(i8{{.+}}%0)
// CHECK-NEXT: start:
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 2
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 4
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[NOT_NICHE:.+]] = icmp{{( samesign)?}} ult i8 %0, 2
// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[NOT_NICHE]], i8 2, i8 %[[REL_VAR]]
// CHECK-NEXT: switch i8 %[[DISCR]]
@@ -443,19 +443,19 @@ pub enum HugeVariantIndex {
V255(Never),
V256(Never),
Possible257,
Bool258(bool),
Possible259,
Possible257, // tag 2
Bool258(bool), // untagged
Possible259, // tag 4
}
// CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 [0-9]+, [0-9]+\))?}} i8 @match5(i8{{.+}}%0)
// CHECK-NEXT: start:
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 3
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[REL_VAR_WIDE:.+]] = zext i8 %[[REL_VAR]] to i64
// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp{{( samesign)?}} ugt i8 %0, 1
// CHECK-NEXT: %[[NICHE_DISCR:.+]] = add nuw nsw i64 %[[REL_VAR_WIDE]], 257
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i64 %[[NICHE_DISCR]], 258
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i64 %[[NICHE_DISCR]], i64 258
// CHECK-NEXT: switch i64 %[[DISCR]],
// CHECK-NEXT: i64 257,