[llvm] [VPlan] Add initial pattern match implementation for VPInstruction. (PR #80563)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 3 13:39:01 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/80563

>From e717be36f16a9169aa4e3f2c88e055705ffbf6b4 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 3 Feb 2024 20:06:38 +0000
Subject: [PATCH 1/4] [VPlan] Add initial pattern match implementation for
 VPInstruction.

Add an initial version of a pattern matcher for VPValues and recipes,
starting with VPInstruction.
---
 llvm/lib/Transforms/Vectorize/VPlan.cpp       |  10 +-
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 130 ++++++++++++++++++
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  21 +--
 3 files changed, 142 insertions(+), 19 deletions(-)
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 2c0daa82afa59f..8300b8abd3edf7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -19,6 +19,7 @@
 #include "VPlan.h"
 #include "VPlanCFG.h"
 #include "VPlanDominatorTree.h"
+#include "VPlanPatternMatch.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -46,6 +47,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace llvm::VPlanPatternMatch;
 
 namespace llvm {
 extern cl::opt<bool> EnableVPlanNativePath;
@@ -552,11 +554,9 @@ static bool hasConditionalTerminator(const VPBasicBlock *VPBB) {
   }
 
   const VPRecipeBase *R = &VPBB->back();
-  auto *VPI = dyn_cast<VPInstruction>(R);
-  bool IsCondBranch =
-      isa<VPBranchOnMaskRecipe>(R) ||
-      (VPI && (VPI->getOpcode() == VPInstruction::BranchOnCond ||
-               VPI->getOpcode() == VPInstruction::BranchOnCount));
+  bool IsCondBranch = isa<VPBranchOnMaskRecipe>(R) ||
+                      match(R, m_BranchOnCond(m_VPValue())) ||
+                      match(R, m_BranchOnCount(m_VPValue(), m_VPValue()));
   (void)IsCondBranch;
 
   if (VPBB->getNumSuccessors() >= 2 || VPBB->isExiting()) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
new file mode 100644
index 00000000000000..03ab21d860108d
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -0,0 +1,150 @@
+//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on the VPlan values and recipes, based on
+// LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+
+#include "VPlan.h"
+
+namespace llvm {
+namespace VPlanPatternMatch {
+
+/// Match \p V against pattern \p P, mirroring llvm::PatternMatch::match. \p P
+/// is taken by const reference so temporary patterns can be passed; the
+/// matchers' match() members are not const-qualified, hence the const_cast.
+template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+  return const_cast<Pattern &>(P).match(V);
+}
+
+/// Matches any value for which isa<Class> holds, without binding it.
+template <typename Class> struct class_match {
+  template <typename ITy> bool match(ITy *V) { return isa<Class>(V); }
+};
+
+/// Match an arbitrary VPValue and ignore it.
+inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
+
+/// Matches a value that dyn_casts to Class and, on success, stores the cast
+/// value through the bound reference VR.
+template <typename Class> struct bind_ty {
+  Class *&VR;
+
+  bind_ty(Class *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    if (auto *CV = dyn_cast<Class>(V)) {
+      VR = CV;
+      return true;
+    }
+    return false;
+  }
+};
+
+/// Match a VPValue, capturing it in \p V if we match.
+inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
+
+/// Matches a VPInstruction with opcode \p Opcode whose single operand matches
+/// the sub-pattern Op0.
+template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
+  Op0_t Op0;
+
+  UnaryVPInstruction_match(Op0_t Op0) : Op0(Op0) {}
+
+  bool match(const VPValue *V) {
+    // Values not defined by a recipe never match.
+    auto *DefR = V->getDefiningRecipe();
+    return DefR && match(DefR);
+  }
+
+  bool match(const VPRecipeBase *R) {
+    auto *DefR = dyn_cast<VPInstruction>(R);
+    if (!DefR)
+      return false;
+    assert((DefR->getOpcode() != Opcode || DefR->getNumOperands() == 1) &&
+           "recipe with matched opcode does not have 1 operands");
+    return DefR->getOpcode() == Opcode && Op0.match(DefR->getOperand(0));
+  }
+};
+
+/// Matches a VPInstruction with opcode \p Opcode whose two operands match the
+/// sub-patterns Op0 and Op1, respectively.
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+struct BinaryVPInstruction_match {
+  Op0_t Op0;
+  Op1_t Op1;
+
+  BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+
+  bool match(const VPValue *V) {
+    // Values not defined by a recipe never match.
+    auto *DefR = V->getDefiningRecipe();
+    return DefR && match(DefR);
+  }
+
+  bool match(const VPRecipeBase *R) {
+    auto *DefR = dyn_cast<VPInstruction>(R);
+    if (!DefR)
+      return false;
+    assert((DefR->getOpcode() != Opcode || DefR->getNumOperands() == 2) &&
+           "recipe with matched opcode does not have 2 operands");
+    return DefR->getOpcode() == Opcode && Op0.match(DefR->getOperand(0)) &&
+           Op1.match(DefR->getOperand(1));
+  }
+};
+
+/// Match a unary VPInstruction with opcode \p Opcode.
+template <unsigned Opcode, typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, Opcode>
+m_VPInstruction(const Op0_t &Op0) {
+  return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
+}
+
+/// Match a binary VPInstruction with opcode \p Opcode.
+template <unsigned Opcode, typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
+m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
+  return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
+}
+
+/// Match VPInstruction::Not.
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
+m_Not(const Op0_t &Op0) {
+  return m_VPInstruction<VPInstruction::Not>(Op0);
+}
+
+/// Match VPInstruction::BranchOnCond.
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
+m_BranchOnCond(const Op0_t &Op0) {
+  return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
+}
+
+/// Match VPInstruction::ActiveLaneMask.
+template <typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
+m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
+}
+
+/// Match VPInstruction::BranchOnCount.
+template <typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
+m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
+}
+} // namespace VPlanPatternMatch
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 71f5285f90236b..af42f35a22b17d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -16,6 +16,7 @@
 #include "VPlanAnalysis.h"
 #include "VPlanCFG.h"
 #include "VPlanDominatorTree.h"
+#include "VPlanPatternMatch.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -26,7 +27,6 @@
 
 using namespace llvm;
 
-using namespace llvm::PatternMatch;
 
 void VPlanTransforms::VPInstructionsToVPRecipes(
     VPlanPtr &Plan,
@@ -475,6 +475,7 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
                  [](VPValue *V) { return V->getNumUsers(); }))
         continue;
 
+      using namespace llvm::PatternMatch;
       // Having side effects keeps R alive, but do remove conditional assume
       // instructions as their conditions may be flattened.
       auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
@@ -577,15 +578,6 @@ void VPlanTransforms::removeRedundantExpandSCEVRecipes(VPlan &Plan) {
   }
 }
 
-static bool canSimplifyBranchOnCond(VPInstruction *Term) {
-  VPInstruction *Not = dyn_cast<VPInstruction>(Term->getOperand(0));
-  if (!Not || Not->getOpcode() != VPInstruction::Not)
-    return false;
-
-  VPInstruction *ALM = dyn_cast<VPInstruction>(Not->getOperand(0));
-  return ALM && ALM->getOpcode() == VPInstruction::ActiveLaneMask;
-}
-
 void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
                                          unsigned BestUF,
                                          PredicatedScalarEvolution &PSE) {
@@ -593,15 +585,16 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
   assert(Plan.hasUF(BestUF) && "BestUF is not available in Plan");
   VPBasicBlock *ExitingVPBB =
       Plan.getVectorLoopRegion()->getExitingBasicBlock();
-  auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back());
+  auto *Term = &ExitingVPBB->back();
   // Try to simplify the branch condition if TC <= VF * UF when preparing to
   // execute the plan for the main vector loop. We only do this if the
   // terminator is:
   //  1. BranchOnCount, or
   //  2. BranchOnCond where the input is Not(ActiveLaneMask).
-  if (!Term || (Term->getOpcode() != VPInstruction::BranchOnCount &&
-                (Term->getOpcode() != VPInstruction::BranchOnCond ||
-                 !canSimplifyBranchOnCond(Term))))
+  using namespace llvm::VPlanPatternMatch;
+  if (!match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) &&
+      !match(Term,
+             m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
     return;
 
   Type *IdxTy =

>From 50d7d55e5f0b1fd1f3feb69f96efd9cb933094f8 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 3 Feb 2024 20:27:25 +0000
Subject: [PATCH 2/4] !fixup fix formatting.

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index af42f35a22b17d..b683900e427c4b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -27,7 +27,6 @@
 
 using namespace llvm;
 
-
 void VPlanTransforms::VPInstructionsToVPRecipes(
     VPlanPtr &Plan,
     function_ref<const InductionDescriptor *(PHINode *)>

>From 0c855b9576ea7f6eb9eebb2dab730fa0ed967ba8 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 26 Feb 2024 09:16:02 +0000
Subject: [PATCH 3/4] !fixup fix formatting

---
 llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 00fae679a04db5..55f08a1d9336ee 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -10,7 +10,9 @@
 // tree-based pattern matches on the VPlan values and recipes, based on
 // LLVM's IR pattern matchers.
 //
-// Currently it provides generic matchers for unary and binary VPInstructions, and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, m_BranchOnCount to match specific VPInstructions.
+// Currently it provides generic matchers for unary and binary VPInstructions,
+// and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond,
+// m_BranchOnCount to match specific VPInstructions.
 // TODO: Add missing matchers for additional opcodes and recipes as needed.
 //
 //===----------------------------------------------------------------------===//

>From 6f3577685babbb23e8bb73299f917089a244c807 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 3 Mar 2024 21:38:20 +0000
Subject: [PATCH 4/4] !fixup address comments, thanks!

---
 llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 55f08a1d9336ee..b90c588b607564 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -65,11 +65,11 @@ template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
 
   bool match(const VPRecipeBase *R) {
     auto *DefR = dyn_cast<VPInstruction>(R);
-    if (!DefR)
+    if (!DefR || DefR->getOpcode() != Opcode)
       return false;
-    assert((DefR->getOpcode() != Opcode || DefR->getNumOperands() == 1) &&
+    assert(DefR->getNumOperands() == 1 &&
            "recipe with matched opcode does not have 1 operands");
-    return DefR->getOpcode() == Opcode && Op0.match(DefR->getOperand(0));
+    return Op0.match(DefR->getOperand(0));
   }
 };
 
@@ -87,12 +87,11 @@ struct BinaryVPInstruction_match {
 
   bool match(const VPRecipeBase *R) {
     auto *DefR = dyn_cast<VPInstruction>(R);
-    if (!DefR)
+    if (!DefR || DefR->getOpcode() != Opcode)
       return false;
-    assert((DefR->getOpcode() != Opcode || DefR->getNumOperands() == 2) &&
+    assert(DefR->getNumOperands() == 2 &&
            "recipe with matched opcode does not have 2 operands");
-    return DefR->getOpcode() == Opcode && Op0.match(DefR->getOperand(0)) &&
-           Op1.match(DefR->getOperand(1));
+    return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1));
   }
 };
 



More information about the llvm-commits mailing list