[Mlir-commits] [mlir] 60f06bc - [mlir][transform] ApplyPatternsOp: Register canonicalization patterns
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 5 02:42:32 PDT 2023
Author: Matthias Springer
Date: 2023-06-05T11:37:43+02:00
New Revision: 60f06bc5bbd9bcb0d8e5e5b879da36fa7f210e84
URL: https://github.com/llvm/llvm-project/commit/60f06bc5bbd9bcb0d8e5e5b879da36fa7f210e84
DIFF: https://github.com/llvm/llvm-project/commit/60f06bc5bbd9bcb0d8e5e5b879da36fa7f210e84.diff
LOG: [mlir][transform] ApplyPatternsOp: Register canonicalization patterns
Also support replacing payload ops with ConstantLike ops in the TrackingListener, even if the replacement op does not have the same name. (Not supported for ops with multiple results, as this would require splitting the handle.)
Differential Revision: https://reviews.llvm.org/D152127
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 57a7bd33acfc5..b6740502d2bd1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -137,7 +137,9 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
Only patterns that were registered in the transform dialect's
`PatternRegistry` are available. Additional patterns can be registered as
- part of transform dialect extensions.
+ part of transform dialect extensions. "canonicalization" is a special set
+ of patterns that refers to all canonicalization patterns of all loaded
+ dialects.
This transform only reads the target handle and modifies the payload. If a
pattern erases or replaces a tracked op, the mapping is updated accordingly.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index d0759941f1ad3..20bed31c34203 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -57,6 +57,16 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
+
+ // Register all canonicalization patterns.
+ getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
+ "canonicalization", [](RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
+ for (Dialect *dialect : ctx->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(patterns);
+ for (RegisteredOperationName op : ctx->getRegisteredOperations())
+ op.getCanonicalizationPatterns(patterns, ctx);
+ });
}
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 49ec075f60700..987c8489703c5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -89,6 +89,11 @@ transform::TrackingListener::findReplacementOp(Operation *op,
if (op->getName() == defOp->getName())
return defOp;
+ // Replacing an op with a constant-like equivalent is a common
+ // canonicalization.
+ if (defOp->hasTrait<OpTrait::ConstantLike>())
+ return defOp;
+
values.clear();
// Skip through ops that implement FindPayloadReplacementOpInterface.
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 0df76d808f880..c51543e6be4c4 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -121,3 +121,23 @@ transform.sequence failures(propagate) {
transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
}
+
+// -----
+
+// CHECK-LABEL: func @canonicalization(
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: return %[[c5]]
+func.func @canonicalization(%t: tensor<5xf32>) -> index {
+ %c0 = arith.constant 0 : index
+ // expected-remark @below {{op was replaced}}
+ %dim = tensor.dim %t, %c0 : tensor<5xf32>
+ return %dim : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+ transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
+}
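
For readers skimming the TransformOps.cpp hunk, the replacement rule that TrackingListener::findReplacementOp applies after this change can be modelled roughly as below. This is only a sketch with stand-in types: `ToyOp` and `findReplacement` are hypothetical names, not MLIR API. The real function works on `Operation *` and then continues by skipping through ops that implement FindPayloadReplacementOpInterface; that step, and the multi-result restriction mentioned in the log message, are left out here.

```cpp
#include <string>

// Hypothetical stand-in for mlir::Operation; not the real MLIR class.
struct ToyOp {
  std::string name;     // operation name, e.g. "tensor.dim"
  bool isConstantLike;  // models defOp->hasTrait<OpTrait::ConstantLike>()
};

// Returns the op a handle should track after `original` has been replaced
// by values defined by `def`, or nullptr if no replacement is accepted.
const ToyOp *findReplacement(const ToyOp &original, const ToyOp *def) {
  if (!def)
    return nullptr;
  // Pre-existing rule: a replacement op with the same name is accepted.
  if (original.name == def->name)
    return def;
  // Rule added by this commit: a ConstantLike replacement is accepted
  // even though its name differs from the original op's name.
  if (def->isConstantLike)
    return def;
  return nullptr;
}
```

In the new test case, `tensor.dim` folds to `arith.constant 5 : index`, so the second rule is what keeps the handle `%0` alive and lets the "op was replaced" remark fire.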