[llvm] dac0f7e - [VPlan] Add general recipe matcher, replace handwritten ones (NFC)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 16:46:52 PDT 2024


Author: Florian Hahn
Date: 2024-10-21T16:46:45-07:00
New Revision: dac0f7e83ebcaed451d5d0bcc3a4c7f949f0c26c

URL: https://github.com/llvm/llvm-project/commit/dac0f7e83ebcaed451d5d0bcc3a4c7f949f0c26c
DIFF: https://github.com/llvm/llvm-project/commit/dac0f7e83ebcaed451d5d0bcc3a4c7f949f0c26c.diff

LOG: [VPlan] Add general recipe matcher, replace handwritten ones (NFC)

The new matcher is more flexible and can be used to build matchers for
additional recipe types without unnecessary duplication.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index ed0bb13d9425f6..1b05afd6b117a5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -120,7 +120,12 @@ template <unsigned Opcode, typename RecipeTy>
 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
   static bool match(const VPRecipeBase *R) {
     auto *DefR = dyn_cast<RecipeTy>(R);
-    return DefR && DefR->getOpcode() == Opcode;
+    // These recipe types have no opcode, so match on the recipe type alone.
+    if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
+                  std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
+      return DefR;
+    else
+      return DefR && DefR->getOpcode() == Opcode;
   }
 };
 
@@ -131,13 +136,34 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
   }
 };
+/// Folds \p P over (element, index) pairs of \p Ops; stops at first false.
+template <typename TupleTy, typename Fn, std::size_t... Is>
+bool checkTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
+  return (P(std::get<Is>(Ops), Is) && ...);
+}
+/// Returns true if \p P(Elt, Idx) holds for every element of \p Ops.
+template <typename TupleTy, typename Fn>
+bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
+  return checkTupleElements(
+      Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
+}
 } // namespace detail
 
-template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
-struct UnaryRecipe_match {
-  Op0_t Op0;
+template <typename Ops_t, unsigned Opcode, bool Commutative,
+          typename... RecipeTys>
+struct Recipe_match {
+  Ops_t Ops;
 
-  UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {}
+  Recipe_match() : Ops() {
+    static_assert(std::tuple_size<Ops_t>::value == 0 &&
+                  "constructor can only be used with zero operands");
+  }
+  Recipe_match(Ops_t Ops) : Ops(Ops) {}
+  template <typename A_t, typename B_t>
+  Recipe_match(A_t A, B_t B) : Ops({A, B}) {
+    static_assert(std::tuple_size<Ops_t>::value == 2 &&
+                  "constructor can only be used for binary matcher");
+  }
 
   bool match(const VPValue *V) const {
     auto *DefR = V->getDefiningRecipe();
@@ -151,12 +177,25 @@ struct UnaryRecipe_match {
   bool match(const VPRecipeBase *R) const {
     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
       return false;
-    assert(R->getNumOperands() == 1 &&
-           "recipe with matched opcode does not have 1 operands");
-    return Op0.match(R->getOperand(0));
+    assert((std::tuple_size<Ops_t>::value == 0 ||
+            R->getNumOperands() == std::tuple_size<Ops_t>::value) &&
+           "matched recipe does not have the expected number of operands");
+    if (detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
+          return Op.match(R->getOperand(Idx));
+        }))
+      return true;
+    // For commutative matchers, retry with the operand order reversed.
+    return Commutative &&
+           detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
+             return Op.match(R->getOperand(R->getNumOperands() - Idx - 1));
+           });
   }
 };
 
+template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
+using UnaryRecipe_match =
+    Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>;
+
 template <typename Op0_t, unsigned Opcode>
 using UnaryVPInstruction_match =
     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
@@ -168,32 +207,8 @@ using AllUnaryRecipe_match =
 
 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
           typename... RecipeTys>
-struct BinaryRecipe_match {
-  Op0_t Op0;
-  Op1_t Op1;
-
-  BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
-
-  bool match(const VPValue *V) const {
-    auto *DefR = V->getDefiningRecipe();
-    return DefR && match(DefR);
-  }
-
-  bool match(const VPSingleDefRecipe *R) const {
-    return match(static_cast<const VPRecipeBase *>(R));
-  }
-
-  bool match(const VPRecipeBase *R) const {
-    if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
-      return false;
-    assert(R->getNumOperands() == 2 &&
-           "recipe with matched opcode does not have 2 operands");
-    if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)))
-      return true;
-    return Commutative && Op0.match(R->getOperand(1)) &&
-           Op1.match(R->getOperand(0));
-  }
-};
+using BinaryRecipe_match =
+    Recipe_match<std::tuple<Op0_t, Op1_t>, Opcode, Commutative, RecipeTys...>;
 
 template <typename Op0_t, typename Op1_t, unsigned Opcode>
 using BinaryVPInstruction_match =
@@ -313,40 +328,16 @@ m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
   return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
 }
 
-struct VPCanonicalIVPHI_match {
-  bool match(const VPValue *V) const {
-    auto *DefR = V->getDefiningRecipe();
-    return DefR && match(DefR);
-  }
-
-  bool match(const VPRecipeBase *R) const {
-    return isa<VPCanonicalIVPHIRecipe>(R);
-  }
-};
+using VPCanonicalIVPHI_match =
+    Recipe_match<std::tuple<>, 0, false, VPCanonicalIVPHIRecipe>;
 
 inline VPCanonicalIVPHI_match m_CanonicalIV() {
   return VPCanonicalIVPHI_match();
 }
 
-template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match {
-  Op0_t Op0;
-  Op1_t Op1;
-
-  VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
-
-  bool match(const VPValue *V) const {
-    auto *DefR = V->getDefiningRecipe();
-    return DefR && match(DefR);
-  }
-
-  bool match(const VPRecipeBase *R) const {
-    if (!isa<VPScalarIVStepsRecipe>(R))
-      return false;
-    assert(R->getNumOperands() == 2 &&
-           "VPScalarIVSteps must have exactly 2 operands");
-    return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
-  }
-};
+template <typename Op0_t, typename Op1_t>
+using VPScalarIVSteps_match =
+    Recipe_match<std::tuple<Op0_t, Op1_t>, 0, false, VPScalarIVStepsRecipe>;
 
 template <typename Op0_t, typename Op1_t>
 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,


        


More information about the llvm-commits mailing list