[Mlir-commits] [mlir] 4b48063 - [mlir][vector][transform] Register vector dialect patterns

Matthias Springer llvmlistbot at llvm.org
Fri Jun 2 07:00:08 PDT 2023


Author: Matthias Springer
Date: 2023-06-02T15:59:56+02:00
New Revision: 4b48063b521bfcc9835269c729de79459d93229e

URL: https://github.com/llvm/llvm-project/commit/4b48063b521bfcc9835269c729de79459d93229e
DIFF: https://github.com/llvm/llvm-project/commit/4b48063b521bfcc9835269c729de79459d93229e.diff

LOG: [mlir][vector][transform] Register vector dialect patterns

Differential Revision: https://reviews.llvm.org/D151983

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 3e3461bb14f6e..e738baf15c8f9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -161,11 +161,21 @@ class PatternRegistry : public TransformDialectData<PatternRegistry> {
 
   /// A function that populates a `RewritePatternSet`.
   using PopulatePatternsFn = std::function<void(RewritePatternSet &)>;
+  /// A function that populates a `RewritePatternSet` with a specified benefit.
+  using PopulatePatternsWithBenefitFn =
+      std::function<void(RewritePatternSet &, PatternBenefit)>;
 
   /// Registers patterns with the specified identifier. The identifier should
   /// be prefixed with the dialect to which the patterns belong.
   void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn);
 
+  /// Registers patterns with the specified identifier. The identifier should
+  /// be prefixed with the dialect to which the patterns belong. This overload
+  /// accepts populate functions that take a `PatternBenefit`; no benefit can
+  /// be specified yet, so `fn` is always invoked with a benefit of 1.
+  void registerPatterns(StringRef identifier,
+                        PopulatePatternsWithBenefitFn &&fn);
+
 protected:
   friend class ApplyPatternsOp;
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c076a8cab89ea..c8326f8fd9b06 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -223,6 +223,17 @@ void transform::PatternRegistry::registerPatterns(StringRef identifier,
   patterns.try_emplace(attr, std::move(fn));
 }
 
+void transform::PatternRegistry::registerPatterns(
+    StringRef identifier, PopulatePatternsWithBenefitFn &&fn) {
+  StringAttr attr = builder.getStringAttr(identifier);
+  assert(!patterns.contains(attr) && "patterns identifier is already in use");
+  // Store `fn` behind a `PopulatePatternsFn` adapter that always passes a
+  // benefit of 1.
+  patterns.try_emplace(attr, [f = std::move(fn)](RewritePatternSet &set) {
+    f(set, /*benefit=*/1);
+  });
+}
+
 void transform::PatternRegistry::populatePatterns(
     StringAttr identifier, RewritePatternSet &patternSet) const {
   auto it = patterns.find(identifier);

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 9c7184de88119..44caaec2d1910 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -188,6 +189,39 @@ class VectorTransformDialectExtension
 #define GET_OP_LIST
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
         >();
+
+    // Make the vector lowering pattern sets available by name through the
+    // transform dialect's PatternRegistry. The lambdas capture nothing, so
+    // they remain valid whenever the registry initializer is invoked.
+    addDialectDataInitializer<transform::PatternRegistry>(
+        [](transform::PatternRegistry &registry) {
+          registry.registerPatterns("vector.outer_product_lowering",
+                                    populateVectorOuterProductLoweringPatterns);
+          registry.registerPatterns("vector.broadcast_lowering",
+                                    populateVectorBroadcastLoweringPatterns);
+          registry.registerPatterns("vector.mask_op_lowering",
+                                    populateVectorMaskOpLoweringPatterns);
+          registry.registerPatterns("vector.shape_cast_lowering",
+                                    populateVectorShapeCastLoweringPatterns);
+          // `populateVectorTransferLoweringPatterns` takes an extra
+          // `maxTransferRank` argument; adapt it by passing `std::nullopt`.
+          registry.registerPatterns(
+              "vector.transfer_lowering",
+              [](RewritePatternSet &set, PatternBenefit benefit) {
+                populateVectorTransferLoweringPatterns(
+                    set, /*maxTransferRank=*/std::nullopt, benefit);
+              });
+          registry.registerPatterns(
+              "vector.transfer_permutation_map_lowering",
+              populateVectorTransferPermutationMapLoweringPatterns);
+          registry.registerPatterns("vector.scan_lowering",
+                                    populateVectorScanLoweringPatterns);
+          registry.registerPatterns("vector.vector_gather_lowering",
+                                    populateVectorGatherLoweringPatterns);
+          registry.registerPatterns(
+              "vector.mask_lowering_for_side_effecting_ops",
+              populateVectorMaskLoweringPatternsForSideEffectingOps);
+        });
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list