[llvm] 2435dcd - [VPlan] Add initial pattern match implementation for VPInstruction. (#80563)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Mar 3 13:49:02 PST 2024
Author: Florian Hahn
Date: 2024-03-03T21:48:58Z
New Revision: 2435dcd83a32c63aef91c82cb19b08604ba96b64
URL: https://github.com/llvm/llvm-project/commit/2435dcd83a32c63aef91c82cb19b08604ba96b64
DIFF: https://github.com/llvm/llvm-project/commit/2435dcd83a32c63aef91c82cb19b08604ba96b64.diff
LOG: [VPlan] Add initial pattern match implementation for VPInstruction. (#80563)
Add an initial version of a pattern match for VPValues and recipes,
starting with VPInstruction.
PR: https://github.com/llvm/llvm-project/pull/80563
Added:
llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Modified:
llvm/lib/Transforms/Vectorize/VPlan.cpp
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 4aeab6fc619988..9768e4b7aa0a8a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -19,6 +19,7 @@
#include "VPlan.h"
#include "VPlanCFG.h"
#include "VPlanDominatorTree.h"
+#include "VPlanPatternMatch.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -46,6 +47,7 @@
#include <vector>
using namespace llvm;
+using namespace llvm::VPlanPatternMatch;
namespace llvm {
extern cl::opt<bool> EnableVPlanNativePath;
@@ -569,11 +571,9 @@ static bool hasConditionalTerminator(const VPBasicBlock *VPBB) {
}
const VPRecipeBase *R = &VPBB->back();
- auto *VPI = dyn_cast<VPInstruction>(R);
- bool IsCondBranch =
- isa<VPBranchOnMaskRecipe>(R) ||
- (VPI && (VPI->getOpcode() == VPInstruction::BranchOnCond ||
- VPI->getOpcode() == VPInstruction::BranchOnCount));
+ bool IsCondBranch = isa<VPBranchOnMaskRecipe>(R) ||
+ match(R, m_BranchOnCond(m_VPValue())) ||
+ match(R, m_BranchOnCount(m_VPValue(), m_VPValue()));
(void)IsCondBranch;
if (VPBB->getNumSuccessors() >= 2 ||
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
new file mode 100644
index 00000000000000..b90c588b607564
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -0,0 +1,136 @@
+//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on the VPlan values and recipes, based on
+// LLVM's IR pattern matchers.
+//
+// Currently it provides generic matchers for unary and binary VPInstructions,
+// and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond,
+// m_BranchOnCount to match specific VPInstructions.
+// TODO: Add missing matchers for additional opcodes and recipes as needed.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
+#define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
+
+#include "VPlan.h"
+
+namespace llvm {
+namespace VPlanPatternMatch {
+
+template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+ return const_cast<Pattern &>(P).match(V);
+}
+
+template <typename Class> struct class_match {
+ template <typename ITy> bool match(ITy *V) { return isa<Class>(V); }
+};
+
+/// Match an arbitrary VPValue and ignore it.
+inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
+
+template <typename Class> struct bind_ty {
+ Class *&VR;
+
+ bind_ty(Class *&V) : VR(V) {}
+
+ template <typename ITy> bool match(ITy *V) {
+ if (auto *CV = dyn_cast<Class>(V)) {
+ VR = CV;
+ return true;
+ }
+ return false;
+ }
+};
+
+/// Match a VPValue, capturing it if we match.
+inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
+
+template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
+ Op0_t Op0;
+
+ UnaryVPInstruction_match(Op0_t Op0) : Op0(Op0) {}
+
+ bool match(const VPValue *V) {
+ auto *DefR = V->getDefiningRecipe();
+ return DefR && match(DefR);
+ }
+
+ bool match(const VPRecipeBase *R) {
+ auto *DefR = dyn_cast<VPInstruction>(R);
+ if (!DefR || DefR->getOpcode() != Opcode)
+ return false;
+ assert(DefR->getNumOperands() == 1 &&
+ "recipe with matched opcode does not have 1 operands");
+ return Op0.match(DefR->getOperand(0));
+ }
+};
+
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+struct BinaryVPInstruction_match {
+ Op0_t Op0;
+ Op1_t Op1;
+
+ BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+
+ bool match(const VPValue *V) {
+ auto *DefR = V->getDefiningRecipe();
+ return DefR && match(DefR);
+ }
+
+ bool match(const VPRecipeBase *R) {
+ auto *DefR = dyn_cast<VPInstruction>(R);
+ if (!DefR || DefR->getOpcode() != Opcode)
+ return false;
+ assert(DefR->getNumOperands() == 2 &&
+ "recipe with matched opcode does not have 2 operands");
+ return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1));
+ }
+};
+
+template <unsigned Opcode, typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, Opcode>
+m_VPInstruction(const Op0_t &Op0) {
+ return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
+}
+
+template <unsigned Opcode, typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
+m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
+ return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
+}
+
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
+m_Not(const Op0_t &Op0) {
+ return m_VPInstruction<VPInstruction::Not>(Op0);
+}
+
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
+m_BranchOnCond(const Op0_t &Op0) {
+ return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
+m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
+m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
+}
+} // namespace VPlanPatternMatch
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 9c3f35112b592f..9d6deb802e2090 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -16,6 +16,7 @@
#include "VPlanAnalysis.h"
#include "VPlanCFG.h"
#include "VPlanDominatorTree.h"
+#include "VPlanPatternMatch.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -26,8 +27,6 @@
using namespace llvm;
-using namespace llvm::PatternMatch;
-
void VPlanTransforms::VPInstructionsToVPRecipes(
VPlanPtr &Plan,
function_ref<const InductionDescriptor *(PHINode *)>
@@ -486,6 +485,7 @@ static void removeDeadRecipes(VPlan &Plan) {
[](VPValue *V) { return V->getNumUsers(); }))
continue;
+ using namespace llvm::PatternMatch;
// Having side effects keeps R alive, but do remove conditional assume
// instructions as their conditions may be flattened.
auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
@@ -595,15 +595,6 @@ static void removeRedundantExpandSCEVRecipes(VPlan &Plan) {
}
}
-static bool canSimplifyBranchOnCond(VPInstruction *Term) {
- VPInstruction *Not = dyn_cast<VPInstruction>(Term->getOperand(0));
- if (!Not || Not->getOpcode() != VPInstruction::Not)
- return false;
-
- VPInstruction *ALM = dyn_cast<VPInstruction>(Not->getOperand(0));
- return ALM && ALM->getOpcode() == VPInstruction::ActiveLaneMask;
-}
-
void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
unsigned BestUF,
PredicatedScalarEvolution &PSE) {
@@ -611,15 +602,16 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
assert(Plan.hasUF(BestUF) && "BestUF is not available in Plan");
VPBasicBlock *ExitingVPBB =
Plan.getVectorLoopRegion()->getExitingBasicBlock();
- auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back());
+ auto *Term = &ExitingVPBB->back();
// Try to simplify the branch condition if TC <= VF * UF when preparing to
// execute the plan for the main vector loop. We only do this if the
// terminator is:
// 1. BranchOnCount, or
// 2. BranchOnCond where the input is Not(ActiveLaneMask).
- if (!Term || (Term->getOpcode() != VPInstruction::BranchOnCount &&
- (Term->getOpcode() != VPInstruction::BranchOnCond ||
- !canSimplifyBranchOnCond(Term))))
+ using namespace llvm::VPlanPatternMatch;
+ if (!match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) &&
+ !match(Term,
+ m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
return;
Type *IdxTy =
More information about the llvm-commits mailing list