[llvm] [GlobalISel] Add and use a m_GAddLike pattern matcher. NFC (PR #125435)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 2 13:39:29 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

This adds a Flags template parameter to BinaryOp_match, allowing it to require that specific instruction flags such as Disjoint are set. An m_GDisjointOr matcher is added to detect G_OR instructions carrying the disjoint flag, and m_GAddLike then matches either m_GAdd or m_GDisjointOr.

The remaining changes allow matching against a `const MachineInstr &`, in addition to non-const references.

---
Full diff: https://github.com/llvm/llvm-project/pull/125435.diff


2 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+52-11) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+1-3) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index 78a92c86b91e4c8..edc2d24a2f6de80 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -33,6 +33,12 @@ template <typename Pattern>
   return P.match(MRI, &MI);
 }
 
+template <typename Pattern>
+[[nodiscard]] bool mi_match(const MachineInstr &MI,
+                            const MachineRegisterInfo &MRI, Pattern &&P) {
+  return P.match(MRI, &MI);
+}
+
 // TODO: Extend for N use.
 template <typename SubPatternT> struct OneUse_match {
   SubPatternT SubPat;
@@ -337,6 +343,21 @@ template <> struct bind_helper<MachineInstr *> {
   }
 };
 
+template <> struct bind_helper<const MachineInstr *> {
+  static bool bind(const MachineRegisterInfo &MRI, const MachineInstr *&MI,
+                   Register Reg) {
+    MI = MRI.getVRegDef(Reg);
+    if (MI)
+      return true;
+    return false;
+  }
+  static bool bind(const MachineRegisterInfo &MRI, const MachineInstr *&MI,
+                   const MachineInstr *Inst) {
+    MI = Inst;
+    return MI;
+  }
+};
+
 template <> struct bind_helper<LLT> {
   static bool bind(const MachineRegisterInfo &MRI, LLT &Ty, Register Reg) {
     Ty = MRI.getType(Reg);
@@ -368,6 +389,9 @@ template <typename Class> struct bind_ty {
 
 inline bind_ty<Register> m_Reg(Register &R) { return R; }
 inline bind_ty<MachineInstr *> m_MInstr(MachineInstr *&MI) { return MI; }
+inline bind_ty<const MachineInstr *> m_MInstr(const MachineInstr *&MI) {
+  return MI;
+}
 inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
 inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
 inline operand_type_match m_Pred() { return operand_type_match(); }
@@ -418,7 +442,7 @@ inline bind_ty<const ConstantFP *> m_GFCst(const ConstantFP *&C) { return C; }
 
 // General helper for all the binary generic MI such as G_ADD/G_SUB etc
 template <typename LHS_P, typename RHS_P, unsigned Opcode,
-          bool Commutable = false>
+          bool Commutable = false, unsigned Flags = MachineInstr::NoFlags>
 struct BinaryOp_match {
   LHS_P L;
   RHS_P R;
@@ -426,18 +450,22 @@ struct BinaryOp_match {
   BinaryOp_match(const LHS_P &LHS, const RHS_P &RHS) : L(LHS), R(RHS) {}
   template <typename OpTy>
   bool match(const MachineRegisterInfo &MRI, OpTy &&Op) {
-    MachineInstr *TmpMI;
+    const MachineInstr *TmpMI;
     if (mi_match(Op, MRI, m_MInstr(TmpMI))) {
       if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) {
-        return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-               // NOTE: When trying the alternative operand ordering
-               // with a commutative operation, it is imperative to always run
-               // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
-               // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
-               // expected.
-               (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
-                               R.match(MRI, TmpMI->getOperand(1).getReg())));
+        if (!(L.match(MRI, TmpMI->getOperand(1).getReg()) &&
+              R.match(MRI, TmpMI->getOperand(2).getReg())) &&
+            // NOTE: When trying the alternative operand ordering
+            // with a commutative operation, it is imperative to always run
+            // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+            // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
+            // expected.
+            !(Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                             R.match(MRI, TmpMI->getOperand(1).getReg()))))
+          return false;
+        if (Flags == MachineInstr::NoFlags)
+          return true;
+        return (TmpMI->getFlags() & Flags) == Flags;
       }
     }
     return false;
@@ -559,6 +587,19 @@ inline BinaryOp_match<LHS, RHS, TargetOpcode::G_OR, true> m_GOr(const LHS &L,
   return BinaryOp_match<LHS, RHS, TargetOpcode::G_OR, true>(L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_OR, true,
+                      MachineInstr::Disjoint>
+m_GDisjointOr(const LHS &L, const RHS &R) {
+  return BinaryOp_match<LHS, RHS, TargetOpcode::G_OR, true,
+                        MachineInstr::Disjoint>(L, R);
+}
+
+template <typename LHS, typename RHS>
+inline auto m_GAddLike(const LHS &L, const RHS &R) {
+  return m_any_of(m_GAdd(L, R), m_GDisjointOr(L, R));
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, TargetOpcode::G_SHL, false>
 m_GShl(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 3c57ba414b2bf07..9b36665d539b96c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1027,9 +1027,7 @@ def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
    return CurDAG->isADDLike(SDValue(N,0));
 }]> {
   let GISelPredicateCode = [{
-     return MI.getOpcode() == TargetOpcode::G_ADD ||
-            (MI.getOpcode() == TargetOpcode::G_OR &&
-             MI.getFlag(MachineInstr::MIFlag::Disjoint));
+     return mi_match(MI, MRI, m_GAddLike(m_Reg(), m_Reg()));
   }];
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/125435


More information about the llvm-commits mailing list