[llvm] 5197520 - [PatternMatch] add matchers for commutative logical and/or

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 7 07:40:55 PDT 2021


Author: Sanjay Patel
Date: 2021-10-07T10:37:34-04:00
New Revision: 519752062c6056adb99a5cff070852c9c698fd0b

URL: https://github.com/llvm/llvm-project/commit/519752062c6056adb99a5cff070852c9c698fd0b
DIFF: https://github.com/llvm/llvm-project/commit/519752062c6056adb99a5cff070852c9c698fd0b.diff

LOG: [PatternMatch] add matchers for commutative logical and/or

We need these to add folds for logical and/or (including their
select forms) that have the same commuted-operand structure as the
existing folds for regular logic ops.

Added: 
    

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/unittests/IR/PatternMatch.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index f7946310e505e..fb5736fbfdd11 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2456,7 +2456,7 @@ inline VScaleVal_match m_VScale(const DataLayout &DL) {
   return VScaleVal_match(DL);
 }
 
-template <typename LHS, typename RHS, unsigned Opcode>
+template <typename LHS, typename RHS, unsigned Opcode, bool Commutable = false>
 struct LogicalOp_match {
   LHS L;
   RHS R;
@@ -2464,27 +2464,32 @@ struct LogicalOp_match {
   LogicalOp_match(const LHS &L, const RHS &R) : L(L), R(R) {}
 
   template <typename T> bool match(T *V) {
-    if (auto *I = dyn_cast<Instruction>(V)) {
-      if (!I->getType()->isIntOrIntVectorTy(1))
-        return false;
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I || !I->getType()->isIntOrIntVectorTy(1))
+      return false;
 
-      if (I->getOpcode() == Opcode && L.match(I->getOperand(0)) &&
-          R.match(I->getOperand(1)))
-        return true;
+    if (I->getOpcode() == Opcode) {
+      auto *Op0 = I->getOperand(0);
+      auto *Op1 = I->getOperand(1);
+      return (L.match(Op0) && R.match(Op1)) ||
+             (Commutable && L.match(Op1) && R.match(Op0));
+    }
 
-      if (auto *SI = dyn_cast<SelectInst>(I)) {
-        if (Opcode == Instruction::And) {
-          if (const auto *C = dyn_cast<Constant>(SI->getFalseValue()))
-            if (C->isNullValue() && L.match(SI->getCondition()) &&
-                R.match(SI->getTrueValue()))
-              return true;
-        } else {
-          assert(Opcode == Instruction::Or);
-          if (const auto *C = dyn_cast<Constant>(SI->getTrueValue()))
-            if (C->isOneValue() && L.match(SI->getCondition()) &&
-                R.match(SI->getFalseValue()))
-              return true;
-        }
+    if (auto *Select = dyn_cast<SelectInst>(I)) {
+      auto *Cond = Select->getCondition();
+      auto *TVal = Select->getTrueValue();
+      auto *FVal = Select->getFalseValue();
+      if (Opcode == Instruction::And) {
+        auto *C = dyn_cast<Constant>(FVal);
+        if (C && C->isNullValue())
+          return (L.match(Cond) && R.match(TVal)) ||
+                 (Commutable && L.match(TVal) && R.match(Cond));
+      } else {
+        assert(Opcode == Instruction::Or);
+        auto *C = dyn_cast<Constant>(TVal);
+        if (C && C->isOneValue())
+          return (L.match(Cond) && R.match(FVal)) ||
+                 (Commutable && L.match(FVal) && R.match(Cond));
       }
     }
 
@@ -2503,6 +2508,13 @@ m_LogicalAnd(const LHS &L, const RHS &R) {
 /// Matches L && R where L and R are arbitrary values.
 inline auto m_LogicalAnd() { return m_LogicalAnd(m_Value(), m_Value()); }
 
+/// Matches L && R (as L & R or L ? R : false) with L and R in either order.
+template <typename LHS, typename RHS>
+inline LogicalOp_match<LHS, RHS, Instruction::And, true>
+m_c_LogicalAnd(const LHS &L, const RHS &R) {
+  return LogicalOp_match<LHS, RHS, Instruction::And, true>(L, R);
+}
+
 /// Matches L || R either in the form of L | R or L ? true : R.
 /// Note that the latter form is poison-blocking.
 template <typename LHS, typename RHS>
@@ -2512,8 +2524,13 @@ m_LogicalOr(const LHS &L, const RHS &R) {
 }
 
 /// Matches L || R where L and R are arbitrary values.
-inline auto m_LogicalOr() {
-  return m_LogicalOr(m_Value(), m_Value());
+inline auto m_LogicalOr() { return m_LogicalOr(m_Value(), m_Value()); }
+
+/// Matches L || R (as L | R or L ? true : R) with L and R in either order.
+template <typename LHS, typename RHS>
+inline LogicalOp_match<LHS, RHS, Instruction::Or, true>
+m_c_LogicalOr(const LHS &L, const RHS &R) {
+  return LogicalOp_match<LHS, RHS, Instruction::Or, true>(L, R);
 }
 
 } // end namespace PatternMatch

diff  --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 1b7aa7f29fb8e..598dcdff943f8 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1635,6 +1635,78 @@ TEST_F(PatternMatchTest, InsertValue) {
   EXPECT_FALSE(match(IRB.getInt64(99), m_InsertValue<0>(m_Value(), m_Value())));
 }
 
+TEST_F(PatternMatchTest, LogicalSelects) {
+  Value *Alloca = IRB.CreateAlloca(IRB.getInt1Ty());
+  Value *X = IRB.CreateLoad(IRB.getInt1Ty(), Alloca);
+  Value *Y = IRB.CreateLoad(IRB.getInt1Ty(), Alloca);
+  Constant *T = IRB.getInt1(true);
+  Constant *F = IRB.getInt1(false);
+  Value *And = IRB.CreateSelect(X, Y, F);
+  Value *Or = IRB.CreateSelect(X, T, Y);
+
+  // Logical and: 'And' is the select form "X ? Y : false".
+  // Check basic no-capture logic - opcode and constant must match.
+  EXPECT_TRUE(match(And, m_LogicalAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Value())));
+  EXPECT_FALSE(match(And, m_LogicalOr(m_Value(), m_Value())));
+  EXPECT_FALSE(match(And, m_c_LogicalOr(m_Value(), m_Value())));
+
+  // Check operand order using specific-value (m_Specific) matchers.
+  EXPECT_TRUE(match(And, m_LogicalAnd(m_Specific(X), m_Value())));
+  EXPECT_TRUE(match(And, m_LogicalAnd(m_Value(), m_Specific(Y))));
+  EXPECT_TRUE(match(And, m_LogicalAnd(m_Specific(X), m_Specific(Y))));
+
+  EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Value())));
+  EXPECT_FALSE(match(And, m_LogicalAnd(m_Value(), m_Specific(X))));
+  EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Specific(X))));
+
+  EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(X), m_Specific(X))));
+  EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Specific(Y))));
+
+  // Commutative matcher: either operand order matches, a repeat does not.
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(X), m_Value())));
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Specific(Y))));
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(X), m_Specific(Y))));
+
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Value())));
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Specific(X))));
+  EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Specific(X))));
+
+  EXPECT_FALSE(match(And, m_c_LogicalAnd(m_Specific(X), m_Specific(X))));
+  EXPECT_FALSE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Specific(Y))));
+
+  // Logical or: 'Or' is the select form "X ? true : Y".
+  // Check basic no-capture logic - opcode and constant must match.
+  EXPECT_TRUE(match(Or, m_LogicalOr(m_Value(), m_Value())));
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Value())));
+  EXPECT_FALSE(match(Or, m_LogicalAnd(m_Value(), m_Value())));
+  EXPECT_FALSE(match(Or, m_c_LogicalAnd(m_Value(), m_Value())));
+
+  // Check operand order using specific-value (m_Specific) matchers.
+  EXPECT_TRUE(match(Or, m_LogicalOr(m_Specific(X), m_Value())));
+  EXPECT_TRUE(match(Or, m_LogicalOr(m_Value(), m_Specific(Y))));
+  EXPECT_TRUE(match(Or, m_LogicalOr(m_Specific(X), m_Specific(Y))));
+
+  EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Value())));
+  EXPECT_FALSE(match(Or, m_LogicalOr(m_Value(), m_Specific(X))));
+  EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Specific(X))));
+
+  EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(X), m_Specific(X))));
+  EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Specific(Y))));
+
+  // Commutative matcher: either operand order matches, a repeat does not.
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(X), m_Value())));
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Specific(Y))));
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(X), m_Specific(Y))));
+
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Value())));
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Specific(X))));
+  EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Specific(X))));
+
+  EXPECT_FALSE(match(Or, m_c_LogicalOr(m_Specific(X), m_Specific(X))));
+  EXPECT_FALSE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Specific(Y))));
+}
+
 TEST_F(PatternMatchTest, VScale) {
   DataLayout DL = M->getDataLayout();
 


        


More information about the llvm-commits mailing list