[Mlir-commits] [mlir] 20245ed - [mlir][transform] Add `apply_cse` option to `transform.apply_patterns` op
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 21 06:19:15 PDT 2023
Author: Matthias Springer
Date: 2023-07-21T15:13:56+02:00
New Revision: 20245ed4dea067d281e5d091badf7bcffbb1445b
URL: https://github.com/llvm/llvm-project/commit/20245ed4dea067d281e5d091badf7bcffbb1445b
DIFF: https://github.com/llvm/llvm-project/commit/20245ed4dea067d281e5d091badf7bcffbb1445b.diff
LOG: [mlir][transform] Add `apply_cse` option to `transform.apply_patterns` op
Applying the canonicalizer and CSE in an interleaved fashion is useful after bufferization (and possibly other transformations) to fold away self-copies.
Differential Revision: https://reviews.llvm.org/D155933
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 37312e8f7420ab..5af2649ae519fc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -172,6 +172,10 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
in which patterns are applied is unspecified; i.e., the ordering of ops in
the region of this op is irrelevant.
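+ If `apply_cse` is set, the greedy pattern rewrite is interleaved with
+ common subexpression elimination (CSE): both are repeated until a fixpoint
+ is reached (up to an internal iteration limit; failing to converge within
+ that limit results in a definite failure).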
+
This transform only reads the target handle and modifies the payload. If a
pattern erases or replaces a tracked op, the mapping is updated accordingly.
@@ -188,7 +192,7 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
}];
let arguments = (ins
- TransformHandleTypeInterface:$target);
+ TransformHandleTypeInterface:$target, UnitAttr:$apply_cse);
let results = (outs);
let regions = (region MaxSizedRegion<1>:$patterns);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 20d572d994c7bc..c9ecec7659ccba 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -317,32 +317,52 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
GreedyRewriteConfig config;
config.listener =
static_cast<RewriterBase::Listener *>(rewriter.getListener());
+ FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+ // Apply patterns and CSE repeatedly until a fixpoint is reached. If no CSE
+ // was requested, apply the greedy pattern rewrite only once. (For targets
+ // that are isolated from above, the greedy pattern rewrite driver already
+ // iterates to a fixpoint internally.)
+ bool cseChanged = false;
+ // One or two iterations should be sufficient. Stop iterating after a certain
+ // threshold to make debugging easier.
+ static const int64_t kNumMaxIterations = 50;
+ int64_t iteration = 0;
+ do {
+ LogicalResult result = failure();
+ if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // Op is isolated from above. Apply patterns and also perform region
+ // simplification.
+ result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
+ } else {
+ // Manually gather list of ops because the other
+ // GreedyPatternRewriteDriver overloads only accept ops that are isolated
+ // from above. This way, patterns can be applied to ops that are not
+ // isolated from above. Regions are not being simplified. Furthermore,
+ // only a single greedy rewrite iteration is performed.
+ SmallVector<Operation *> ops;
+ target->walk([&](Operation *nestedOp) {
+ if (target != nestedOp)
+ ops.push_back(nestedOp);
+ });
+ result = applyOpPatternsAndFold(ops, frozenPatterns, config);
+ }
- LogicalResult result = failure();
- if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- // Op is isolated from above. Apply patterns and also perform region
- // simplification.
- result = applyPatternsAndFoldGreedily(target, std::move(patterns), config);
- } else {
- // Manually gather list of ops because the other GreedyPatternRewriteDriver
- // overloads only accepts ops that are isolated from above. This way,
- // patterns can be applied to ops that are not isolated from above. Regions
- // are not being simplified. Furthermore, only a single greedy rewrite
- // iteration is performed.
- SmallVector<Operation *> ops;
- target->walk([&](Operation *nestedOp) {
- if (target != nestedOp)
- ops.push_back(nestedOp);
- });
- result = applyOpPatternsAndFold(ops, std::move(patterns), config);
- }
+ // A failure typically indicates that the pattern application did not
+ // converge.
+ if (failed(result)) {
+ return emitSilenceableFailure(target)
+ << "greedy pattern application failed";
+ }
- // A failure typically indicates that the pattern application did not
- // converge.
- if (failed(result)) {
- return emitSilenceableFailure(target)
- << "greedy pattern application failed";
- }
+ if (getApplyCse()) {
+ DominanceInfo domInfo;
+ mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
+ &cseChanged);
+ }
+ } while (cseChanged && ++iteration < kNumMaxIterations);
+
+ if (iteration == kNumMaxIterations)
+ return emitDefiniteFailure() << "fixpoint iteration did not converge";
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 992f78623a825a..a9a5e43cc06774 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -210,3 +210,24 @@ module {
}
}
}
+
+// -----
+
+// CHECK-LABEL: func @canonicalization_and_cse(
+// CHECK-NOT: memref.subview
+// CHECK-NOT: memref.copy
+func.func @canonicalization_and_cse(%m: memref<5xf32>) {
+ %c2 = arith.constant 2 : index
+ %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, strided<[1], offset: 1>>
+ %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref<?xf32, strided<[1], offset: 1>>
+ memref.copy %s0, %s1 : memref<2xf32, strided<[1], offset: 1>> to memref<?xf32, strided<[1], offset: 1>>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %1 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+}
More information about the Mlir-commits
mailing list