[Mlir-commits] [mlir] 4b48063 - [mlir][vector][transform] Register vector dialect patterns
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 2 07:00:08 PDT 2023
Author: Matthias Springer
Date: 2023-06-02T15:59:56+02:00
New Revision: 4b48063b521bfcc9835269c729de79459d93229e
URL: https://github.com/llvm/llvm-project/commit/4b48063b521bfcc9835269c729de79459d93229e
DIFF: https://github.com/llvm/llvm-project/commit/4b48063b521bfcc9835269c729de79459d93229e.diff
LOG: [mlir][vector][transform] Register vector dialect patterns
Differential Revision: https://reviews.llvm.org/D151983
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 3e3461bb14f6e..e738baf15c8f9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -161,11 +161,20 @@ class PatternRegistry : public TransformDialectData<PatternRegistry> {
/// A function that populates a `RewritePatternSet`.
using PopulatePatternsFn = std::function<void(RewritePatternSet &)>;
+ /// A function that populates a `RewritePatternSet` with a specified benefit.
+ using PopulatePatternsWithBenefitFn =
+ std::function<void(RewritePatternSet &, PatternBenefit)>;
/// Registers patterns with the specified identifier. The identifier should
/// be prefixed with the dialect to which the patterns belong.
void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn);
+ /// Registers patterns with the specified identifier. The identifier should
+ /// be prefixed with the dialect to which the patterns belong. `fn` is
+ /// currently always invoked with a fixed benefit of 1.
+ void registerPatterns(StringRef identifier,
+ PopulatePatternsWithBenefitFn &&fn);
+
protected:
friend class ApplyPatternsOp;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c076a8cab89ea..c8326f8fd9b06 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -223,6 +223,15 @@ void transform::PatternRegistry::registerPatterns(StringRef identifier,
patterns.try_emplace(attr, std::move(fn));
}
+// Registers a benefit-aware populate function under `identifier`. `fn` is
+// wrapped in a callback that takes only the RewritePatternSet and always
+// invokes `fn` with a benefit of 1.
+void transform::PatternRegistry::registerPatterns(
+    StringRef identifier, PopulatePatternsWithBenefitFn &&fn) {
+  StringAttr attr = builder.getStringAttr(identifier);
+  assert(!patterns.contains(attr) && "patterns identifier is already in use");
+  patterns.try_emplace(attr, [f = std::move(fn)](RewritePatternSet &set) {
+    f(set, /*benefit=*/1);
+  });
+}
+
void transform::PatternRegistry::populatePatterns(
StringAttr identifier, RewritePatternSet &patternSet) const {
auto it = patterns.find(identifier);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 9c7184de88119..44caaec2d1910 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -188,6 +189,34 @@ class VectorTransformDialectExtension
#define GET_OP_LIST
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
>();
+
+    // Register the vector lowering pattern sets with the transform dialect's
+    // PatternRegistry under "vector."-prefixed identifiers. The registered
+    // callbacks may outlive this scope, so the lambdas capture nothing.
+    addDialectDataInitializer<transform::PatternRegistry>(
+        [](transform::PatternRegistry &registry) {
+          registry.registerPatterns("vector.outer_product_lowering",
+                                    populateVectorOuterProductLoweringPatterns);
+          registry.registerPatterns("vector.broadcast_lowering",
+                                    populateVectorBroadcastLoweringPatterns);
+          registry.registerPatterns("vector.mask_op_lowering",
+                                    populateVectorMaskOpLoweringPatterns);
+          registry.registerPatterns("vector.shape_cast_lowering",
+                                    populateVectorShapeCastLoweringPatterns);
+          // Transfer lowering is registered without a transfer-rank limit.
+          registry.registerPatterns(
+              "vector.transfer_lowering",
+              [](RewritePatternSet &set, PatternBenefit benefit) {
+                populateVectorTransferLoweringPatterns(
+                    set, /*maxTransferRank=*/std::nullopt, benefit);
+              });
+          registry.registerPatterns(
+              "vector.transfer_permutation_map_lowering",
+              populateVectorTransferPermutationMapLoweringPatterns);
+          registry.registerPatterns("vector.scan_lowering",
+                                    populateVectorScanLoweringPatterns);
+          registry.registerPatterns("vector.vector_gather_lowering",
+                                    populateVectorGatherLoweringPatterns);
+          registry.registerPatterns(
+              "vector.mask_lowering_for_side_effecting_ops",
+              populateVectorMaskLoweringPatternsForSideEffectingOps);
+        });
}
};
} // namespace
More information about the Mlir-commits
mailing list