[llvm] [InstCombine] Remove AllOnes fallbacks in getMaskedTypeForICmpPair() (PR #104941)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 07:20:39 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/104941

getMaskedTypeForICmpPair() tries to model non-and operands as x & -1. However, this can confuse the matching logic: it may pick the -1 operand as the "common" operand, resulting in a successful but useless match. This is what causes commutation failures for some of the optimizations driven by this function.

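Fix this by removing the -1 fallback entirely. We don't seem to have any test coverage demonstrating that it is needed; the only test changes (main7b and main7c in bit-checks.ll) are commuted patterns that now fold into a single icmp.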

From 516bf1a8828434faf9608dd61d6133f934d264f4 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 20 Aug 2024 16:07:07 +0200
Subject: [PATCH] [InstCombine] Remove AllOnes fallbacks in
 getMaskedTypeForICmpPair()

getMaskedTypeForICmpPair() tries to model non-and operands as
x & -1. However, this can confuse the matching logic: it may
pick the -1 operand as the "common" operand, resulting in a
successful but useless match. This is what causes commutation
failures for some of the optimizations driven by this function.

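Fix this by removing the -1 fallback entirely. We don't seem to
have any test coverage demonstrating that it is needed; the only
test changes (main7b and main7c in bit-checks.ll) are commuted
patterns that now fold into a single icmp.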
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 83 ++++++++-----------
 .../test/Transforms/InstCombine/bit-checks.ll | 20 ++---
 2 files changed, 41 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2bba83b5cde3c7..e7f21a42105add 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -211,23 +211,18 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
   // above.
   Value *L1 = LHS->getOperand(0);
   Value *L2 = LHS->getOperand(1);
-  Value *L11, *L12, *L21, *L22;
+  Value *L11 = nullptr, *L12 = nullptr, *L21 = nullptr, *L22 = nullptr;
   // Check whether the icmp can be decomposed into a bit test.
   if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
     L21 = L22 = L1 = nullptr;
   } else {
     // Look for ANDs in the LHS icmp.
-    if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
-      // Any icmp can be viewed as being trivially masked; if it allows us to
-      // remove one, it's worth it.
-      L11 = L1;
-      L12 = Constant::getAllOnesValue(L1->getType());
-    }
+    match(L1, m_And(m_Value(L11), m_Value(L12)));
+    match(L2, m_And(m_Value(L21), m_Value(L22)));
 
-    if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
-      L21 = L2;
-      L22 = Constant::getAllOnesValue(L2->getType());
-    }
+    // Check that at least one and was found.
+    if (!L11 && !L21)
+      return std::nullopt;
   }
 
   // Bail if LHS was a icmp that can't be decomposed into an equality.
@@ -252,54 +247,42 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
     R1 = nullptr;
     Ok = true;
   } else {
-    if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
-      // As before, model no mask as a trivial mask if it'll let us do an
-      // optimization.
-      R11 = R1;
-      R12 = Constant::getAllOnesValue(R1->getType());
+    if (match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R2;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R2;
+        Ok = true;
+      }
     }
 
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R2;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-      E = R2;
-      Ok = true;
+    if (match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R1;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R1;
+        Ok = true;
+      }
     }
+
+    if (!Ok)
+      return std::nullopt;
   }
 
   // Bail if RHS was a icmp that can't be decomposed into an equality.
   if (!ICmpInst::isEquality(PredR))
     return std::nullopt;
 
-  // Look for ANDs on the right side of the RHS icmp.
-  if (!Ok) {
-    if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
-      R11 = R2;
-      R12 = Constant::getAllOnesValue(R2->getType());
-    }
-
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R1;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-      E = R1;
-      Ok = true;
-    } else {
-      return std::nullopt;
-    }
-
-    assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
-  }
-
   if (L11 == A) {
     B = L12;
     C = L2;
diff --git a/llvm/test/Transforms/InstCombine/bit-checks.ll b/llvm/test/Transforms/InstCombine/bit-checks.ll
index c7e1fbb8945493..906e57e0979635 100644
--- a/llvm/test/Transforms/InstCombine/bit-checks.ll
+++ b/llvm/test/Transforms/InstCombine/bit-checks.ll
@@ -809,12 +809,10 @@ define i32 @main7a_logical(i32 %argc, i32 %argc2, i32 %argc3) {
 define i32 @main7b(i32 %argc, i32 %argc2, i32 %argc3x) {
 ; CHECK-LABEL: @main7b(
 ; CHECK-NEXT:    [[ARGC3:%.*]] = mul i32 [[ARGC3X:%.*]], 42
-; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[ARGC:%.*]], [[ARGC2:%.*]]
-; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i32 [[AND1]], [[ARGC2]]
-; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[ARGC3]], [[ARGC]]
-; CHECK-NEXT:    [[TOBOOL3:%.*]] = icmp ne i32 [[ARGC3]], [[AND2]]
-; CHECK-NEXT:    [[AND_COND_NOT:%.*]] = or i1 [[TOBOOL]], [[TOBOOL3]]
-; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext i1 [[AND_COND_NOT]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[ARGC3]], [[ARGC2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[ARGC:%.*]]
+; CHECK-NEXT:    [[AND_COND:%.*]] = icmp ne i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext i1 [[AND_COND]] to i32
 ; CHECK-NEXT:    ret i32 [[STOREMERGE]]
 ;
   %argc3 = mul i32 %argc3x, 42 ; thwart complexity-based canonicalization
@@ -850,12 +848,10 @@ define i32 @main7b_logical(i32 %argc, i32 %argc2, i32 %argc3) {
 define i32 @main7c(i32 %argc, i32 %argc2, i32 %argc3x) {
 ; CHECK-LABEL: @main7c(
 ; CHECK-NEXT:    [[ARGC3:%.*]] = mul i32 [[ARGC3X:%.*]], 42
-; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[ARGC2:%.*]], [[ARGC:%.*]]
-; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i32 [[AND1]], [[ARGC2]]
-; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[ARGC3]], [[ARGC]]
-; CHECK-NEXT:    [[TOBOOL3:%.*]] = icmp ne i32 [[ARGC3]], [[AND2]]
-; CHECK-NEXT:    [[AND_COND_NOT:%.*]] = or i1 [[TOBOOL]], [[TOBOOL3]]
-; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext i1 [[AND_COND_NOT]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[ARGC3]], [[ARGC2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[ARGC:%.*]]
+; CHECK-NEXT:    [[AND_COND:%.*]] = icmp ne i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext i1 [[AND_COND]] to i32
 ; CHECK-NEXT:    ret i32 [[STOREMERGE]]
 ;
   %argc3 = mul i32 %argc3x, 42 ; thwart complexity-based canonicalization



More information about the llvm-commits mailing list