[llvm] 5197520 - [PatternMatch] add matchers for commutative logical and/or
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 7 07:40:55 PDT 2021
Author: Sanjay Patel
Date: 2021-10-07T10:37:34-04:00
New Revision: 519752062c6056adb99a5cff070852c9c698fd0b
URL: https://github.com/llvm/llvm-project/commit/519752062c6056adb99a5cff070852c9c698fd0b
DIFF: https://github.com/llvm/llvm-project/commit/519752062c6056adb99a5cff070852c9c698fd0b.diff
LOG: [PatternMatch] add matchers for commutative logical and/or
We need these to add folds with the same structure as
regular commuted logic ops.
Added:
Modified:
llvm/include/llvm/IR/PatternMatch.h
llvm/unittests/IR/PatternMatch.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index f7946310e505e..fb5736fbfdd11 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2456,7 +2456,7 @@ inline VScaleVal_match m_VScale(const DataLayout &DL) {
return VScaleVal_match(DL);
}
-template <typename LHS, typename RHS, unsigned Opcode>
+template <typename LHS, typename RHS, unsigned Opcode, bool Commutable = false>
struct LogicalOp_match {
LHS L;
RHS R;
@@ -2464,27 +2464,32 @@ struct LogicalOp_match {
LogicalOp_match(const LHS &L, const RHS &R) : L(L), R(R) {}
template <typename T> bool match(T *V) {
- if (auto *I = dyn_cast<Instruction>(V)) {
- if (!I->getType()->isIntOrIntVectorTy(1))
- return false;
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I || !I->getType()->isIntOrIntVectorTy(1))
+ return false;
- if (I->getOpcode() == Opcode && L.match(I->getOperand(0)) &&
- R.match(I->getOperand(1)))
- return true;
+ if (I->getOpcode() == Opcode) {
+ auto *Op0 = I->getOperand(0);
+ auto *Op1 = I->getOperand(1);
+ return (L.match(Op0) && R.match(Op1)) ||
+ (Commutable && L.match(Op1) && R.match(Op0));
+ }
- if (auto *SI = dyn_cast<SelectInst>(I)) {
- if (Opcode == Instruction::And) {
- if (const auto *C = dyn_cast<Constant>(SI->getFalseValue()))
- if (C->isNullValue() && L.match(SI->getCondition()) &&
- R.match(SI->getTrueValue()))
- return true;
- } else {
- assert(Opcode == Instruction::Or);
- if (const auto *C = dyn_cast<Constant>(SI->getTrueValue()))
- if (C->isOneValue() && L.match(SI->getCondition()) &&
- R.match(SI->getFalseValue()))
- return true;
- }
+ if (auto *Select = dyn_cast<SelectInst>(I)) {
+ auto *Cond = Select->getCondition();
+ auto *TVal = Select->getTrueValue();
+ auto *FVal = Select->getFalseValue();
+ if (Opcode == Instruction::And) {
+ auto *C = dyn_cast<Constant>(FVal);
+ if (C && C->isNullValue())
+ return (L.match(Cond) && R.match(TVal)) ||
+ (Commutable && L.match(TVal) && R.match(Cond));
+ } else {
+ assert(Opcode == Instruction::Or);
+ auto *C = dyn_cast<Constant>(TVal);
+ if (C && C->isOneValue())
+ return (L.match(Cond) && R.match(FVal)) ||
+ (Commutable && L.match(FVal) && R.match(Cond));
}
}
@@ -2503,6 +2508,13 @@ m_LogicalAnd(const LHS &L, const RHS &R) {
/// Matches L && R where L and R are arbitrary values.
inline auto m_LogicalAnd() { return m_LogicalAnd(m_Value(), m_Value()); }
+/// Matches L && R (L & R or L ? R : false) with LHS and RHS in either order.
+template <typename LHS, typename RHS>
+inline LogicalOp_match<LHS, RHS, Instruction::And, true>
+m_c_LogicalAnd(const LHS &L, const RHS &R) {
+ return LogicalOp_match<LHS, RHS, Instruction::And, true>(L, R);
+}
+
/// Matches L || R either in the form of L | R or L ? true : R.
/// Note that the latter form is poison-blocking.
template <typename LHS, typename RHS>
@@ -2512,8 +2524,13 @@ m_LogicalOr(const LHS &L, const RHS &R) {
}
/// Matches L || R where L and R are arbitrary values.
-inline auto m_LogicalOr() {
- return m_LogicalOr(m_Value(), m_Value());
+inline auto m_LogicalOr() { return m_LogicalOr(m_Value(), m_Value()); }
+
+/// Matches L || R (L | R or L ? true : R) with LHS and RHS in either order.
+template <typename LHS, typename RHS>
+inline LogicalOp_match<LHS, RHS, Instruction::Or, true>
+m_c_LogicalOr(const LHS &L, const RHS &R) {
+ return LogicalOp_match<LHS, RHS, Instruction::Or, true>(L, R);
}
} // end namespace PatternMatch
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 1b7aa7f29fb8e..598dcdff943f8 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1635,6 +1635,78 @@ TEST_F(PatternMatchTest, InsertValue) {
EXPECT_FALSE(match(IRB.getInt64(99), m_InsertValue<0>(m_Value(), m_Value())));
}
+TEST_F(PatternMatchTest, LogicalSelects) {
+ Value *Alloca = IRB.CreateAlloca(IRB.getInt1Ty());
+ Value *X = IRB.CreateLoad(IRB.getInt1Ty(), Alloca);
+ Value *Y = IRB.CreateLoad(IRB.getInt1Ty(), Alloca);
+ Constant *T = IRB.getInt1(true);
+ Constant *F = IRB.getInt1(false);
+ Value *And = IRB.CreateSelect(X, Y, F);
+ Value *Or = IRB.CreateSelect(X, T, Y);
+
+ // Logical and:
+ // Check basic no-capture logic - opcode and constant must match.
+ EXPECT_TRUE(match(And, m_LogicalAnd(m_Value(), m_Value())));
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Value())));
+ EXPECT_FALSE(match(And, m_LogicalOr(m_Value(), m_Value())));
+ EXPECT_FALSE(match(And, m_c_LogicalOr(m_Value(), m_Value())));
+
+ // Check with specific operands - operand order must match.
+ EXPECT_TRUE(match(And, m_LogicalAnd(m_Specific(X), m_Value())));
+ EXPECT_TRUE(match(And, m_LogicalAnd(m_Value(), m_Specific(Y))));
+ EXPECT_TRUE(match(And, m_LogicalAnd(m_Specific(X), m_Specific(Y))));
+
+ EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Value())));
+ EXPECT_FALSE(match(And, m_LogicalAnd(m_Value(), m_Specific(X))));
+ EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Specific(X))));
+
+ EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(X), m_Specific(X))));
+ EXPECT_FALSE(match(And, m_LogicalAnd(m_Specific(Y), m_Specific(Y))));
+
+ // Check specific operands for commutative match - either order matches.
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(X), m_Value())));
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Specific(Y))));
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(X), m_Specific(Y))));
+
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Value())));
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Value(), m_Specific(X))));
+ EXPECT_TRUE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Specific(X))));
+
+ EXPECT_FALSE(match(And, m_c_LogicalAnd(m_Specific(X), m_Specific(X))));
+ EXPECT_FALSE(match(And, m_c_LogicalAnd(m_Specific(Y), m_Specific(Y))));
+
+ // Logical or:
+ // Check basic no-capture logic - opcode and constant must match.
+ EXPECT_TRUE(match(Or, m_LogicalOr(m_Value(), m_Value())));
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Value())));
+ EXPECT_FALSE(match(Or, m_LogicalAnd(m_Value(), m_Value())));
+ EXPECT_FALSE(match(Or, m_c_LogicalAnd(m_Value(), m_Value())));
+
+ // Check with specific operands - operand order must match.
+ EXPECT_TRUE(match(Or, m_LogicalOr(m_Specific(X), m_Value())));
+ EXPECT_TRUE(match(Or, m_LogicalOr(m_Value(), m_Specific(Y))));
+ EXPECT_TRUE(match(Or, m_LogicalOr(m_Specific(X), m_Specific(Y))));
+
+ EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Value())));
+ EXPECT_FALSE(match(Or, m_LogicalOr(m_Value(), m_Specific(X))));
+ EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Specific(X))));
+
+ EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(X), m_Specific(X))));
+ EXPECT_FALSE(match(Or, m_LogicalOr(m_Specific(Y), m_Specific(Y))));
+
+ // Check specific operands for commutative match - either order matches.
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(X), m_Value())));
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Specific(Y))));
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(X), m_Specific(Y))));
+
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Value())));
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Value(), m_Specific(X))));
+ EXPECT_TRUE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Specific(X))));
+
+ EXPECT_FALSE(match(Or, m_c_LogicalOr(m_Specific(X), m_Specific(X))));
+ EXPECT_FALSE(match(Or, m_c_LogicalOr(m_Specific(Y), m_Specific(Y))));
+}
+
TEST_F(PatternMatchTest, VScale) {
DataLayout DL = M->getDataLayout();
More information about the llvm-commits
mailing list