[llvm] [MIPatternMatch] Add m_DeferredReg/Type (PR #121218)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 27 10:14:59 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-globalisel

Author: Min-Yih Hsu (mshockwave)

<details>
<summary>Changes</summary>

This pattern does the same thing as m_SpecificReg/Type except the value it matches against originated from an earlier pattern in the same mi_match expression.

This patch also changes how commutative patterns are handled: in order to support m_DeferredReg/Type, we always have to run the LHS-pattern before the RHS one.

---
Full diff: https://github.com/llvm/llvm-project/pull/121218.diff


2 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+48-4) 
- (modified) llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp (+30) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index 47417f53b6e40a..450aaa17ee2665 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -372,6 +372,36 @@ inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
 inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
 inline operand_type_match m_Pred() { return operand_type_match(); }
 
+template <typename BindTy> struct deferred_helper {
+  static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
+    return VR == V;
+  }
+};
+
+template <> struct deferred_helper<LLT> {
+  static bool match(const MachineRegisterInfo &MRI, LLT VT, Register R) {
+    return VT == MRI.getType(R);
+  }
+};
+
+template <typename Class> struct deferred_ty {
+  Class &VR;
+
+  deferred_ty(Class &V) : VR(V) {}
+
+  template <typename ITy> bool match(const MachineRegisterInfo &MRI, ITy &&V) {
+    return deferred_helper<Class>::match(MRI, VR, V);
+  }
+};
+
+/// Similar to m_SpecificReg/Type, but the specific value to match originated
+/// from an earlier sub-pattern in the same mi_match expression. For example,
+/// we cannot match `(add X, X)` with `m_GAdd(m_Reg(X), m_SpecificReg(X))`
+/// because `X` is not initialized at the time it's passed to `m_SpecificReg`.
+/// Instead, we can use `m_GAdd(m_Reg(x), m_DeferredReg(X))`.
+inline deferred_ty<Register> m_DeferredReg(Register &R) { return R; }
+inline deferred_ty<LLT> m_DeferredType(LLT &Ty) { return Ty; }
+
 struct ImplicitDefMatch {
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     MachineInstr *TmpMI;
@@ -401,8 +431,13 @@ struct BinaryOp_match {
       if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) {
         return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
                 R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-               (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                               L.match(MRI, TmpMI->getOperand(2).getReg())));
+               // NOTE: When trying the alternative different operand ordering
+               // with a commutative operation, it is imperative to always run
+               // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+               // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
+               // expected.
+               (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                               R.match(MRI, TmpMI->getOperand(1).getReg())));
       }
     }
     return false;
@@ -426,8 +461,13 @@ struct BinaryOpc_match {
           TmpMI->getNumOperands() == 3) {
         return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
                 R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-               (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                               L.match(MRI, TmpMI->getOperand(2).getReg())));
+               // NOTE: When trying the alternative different operand ordering
+               // with a commutative operation, it is imperative to always run
+               // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+               // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
+               // expected.
+               (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                               R.match(MRI, TmpMI->getOperand(1).getReg())));
       }
     }
     return false;
@@ -674,6 +714,10 @@ struct CompareOp_match {
     Register RHS = TmpMI->getOperand(3).getReg();
     if (L.match(MRI, LHS) && R.match(MRI, RHS))
       return true;
+    // NOTE: When trying the alternative different operand ordering
+    // with a commutative operation, it is imperative to always run
+    // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+    // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as expected.
     if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) &&
         P.match(MRI, CmpInst::getSwappedPredicate(TmpPred)))
       return true;
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index fc76d4055722e4..40cd055c1c3f80 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -920,6 +920,36 @@ TEST_F(AArch64GISelMITest, MatchSpecificReg) {
   EXPECT_TRUE(mi_match(Add.getReg(0), *MRI, m_GAdd(m_SpecificReg(Reg), m_Reg())));
 }
 
+TEST_F(AArch64GISelMITest, DeferredMatching) {
+  setUp();
+  if (!TM)
+    GTEST_SKIP();
+  auto s64 = LLT::scalar(64);
+  auto s32 = LLT::scalar(32);
+
+  auto Cst1 = B.buildConstant(s64, 42);
+  auto Cst2 = B.buildConstant(s64, 314);
+  auto Add = B.buildAdd(s64, Cst1, Cst2);
+  auto Sub = B.buildSub(s64, Add, Cst1);
+
+  auto TruncAdd = B.buildTrunc(s32, Add);
+  auto TruncSub = B.buildTrunc(s32, Sub);
+  auto NarrowAdd = B.buildAdd(s32, TruncAdd, TruncSub);
+
+  Register X;
+  EXPECT_TRUE(mi_match(Sub.getReg(0), *MRI,
+                       m_GSub(m_GAdd(m_Reg(X), m_Reg()), m_DeferredReg(X))));
+  LLT Ty;
+  EXPECT_TRUE(
+      mi_match(NarrowAdd.getReg(0), *MRI,
+               m_GAdd(m_GTrunc(m_Type(Ty)), m_GTrunc(m_DeferredType(Ty)))));
+
+  // Test commutative.
+  auto Add2 = B.buildAdd(s64, Sub, Cst1);
+  EXPECT_TRUE(mi_match(Add2.getReg(0), *MRI,
+                       m_GAdd(m_Reg(X), m_GSub(m_Reg(), m_DeferredReg(X)))));
+}
+
 } // namespace
 
 int main(int argc, char **argv) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/121218


More information about the llvm-commits mailing list