[Mlir-commits] [mlir] 20245ed - [mlir][transform] Add `apply_cse` option to `transform.apply_patterns` op

Matthias Springer llvmlistbot at llvm.org
Fri Jul 21 06:19:15 PDT 2023


Author: Matthias Springer
Date: 2023-07-21T15:13:56+02:00
New Revision: 20245ed4dea067d281e5d091badf7bcffbb1445b

URL: https://github.com/llvm/llvm-project/commit/20245ed4dea067d281e5d091badf7bcffbb1445b
DIFF: https://github.com/llvm/llvm-project/commit/20245ed4dea067d281e5d091badf7bcffbb1445b.diff

LOG: [mlir][transform] Add `apply_cse` option to `transform.apply_patterns` op

Applying the canonicalizer and CSE in an interleaved fashion is useful after bufferization (and maybe other transforms) to fold away self copies.

Differential Revision: https://reviews.llvm.org/D155933

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-pattern-application.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 37312e8f7420ab..5af2649ae519fc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -172,6 +172,10 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     in which patterns are applied is unspecified; i.e., the ordering of ops in
     the region of this op is irrelevant.
 
+    If `apply_cse` is set, the greedy pattern rewrite is interleaved with
+    common subexpression elimination (CSE): both are repeated until a fixpoint
+    is reached.
+
     This transform only reads the target handle and modifies the payload. If a
     pattern erases or replaces a tracked op, the mapping is updated accordingly.
 
@@ -188,7 +192,7 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
   }];
 
   let arguments = (ins
-    TransformHandleTypeInterface:$target);
+    TransformHandleTypeInterface:$target, UnitAttr:$apply_cse);
   let results = (outs);
   let regions = (region MaxSizedRegion<1>:$patterns);
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 20d572d994c7bc..c9ecec7659ccba 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -317,32 +317,52 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
   GreedyRewriteConfig config;
   config.listener =
       static_cast<RewriterBase::Listener *>(rewriter.getListener());
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  // Apply patterns and CSE repeatedly until a fixpoint is reached. If no CSE
+  // was requested, apply the greedy pattern rewrite only once. (The greedy
+  // pattern rewrite driver already iterates to a fixpoint internally.)
+  bool cseChanged = false;
+  // One or two iterations should be sufficient. Stop iterating after a certain
+  // threshold to make debugging easier.
+  static const int64_t kNumMaxIterations = 50;
+  int64_t iteration = 0;
+  do {
+    LogicalResult result = failure();
+    if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      // Op is isolated from above. Apply patterns and also perform region
+      // simplification.
+      result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
+    } else {
+      // Manually gather list of ops because the other
+      // GreedyPatternRewriteDriver overloads only accept ops that are isolated
+      // from above. This way, patterns can be applied to ops that are not
+      // isolated from above. Regions are not being simplified. Furthermore,
+      // only a single greedy rewrite iteration is performed.
+      SmallVector<Operation *> ops;
+      target->walk([&](Operation *nestedOp) {
+        if (target != nestedOp)
+          ops.push_back(nestedOp);
+      });
+      result = applyOpPatternsAndFold(ops, frozenPatterns, config);
+    }
 
-  LogicalResult result = failure();
-  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-    // Op is isolated from above. Apply patterns and also perform region
-    // simplification.
-    result = applyPatternsAndFoldGreedily(target, std::move(patterns), config);
-  } else {
-    // Manually gather list of ops because the other GreedyPatternRewriteDriver
-    // overloads only accepts ops that are isolated from above. This way,
-    // patterns can be applied to ops that are not isolated from above. Regions
-    // are not being simplified. Furthermore, only a single greedy rewrite
-    // iteration is performed.
-    SmallVector<Operation *> ops;
-    target->walk([&](Operation *nestedOp) {
-      if (target != nestedOp)
-        ops.push_back(nestedOp);
-    });
-    result = applyOpPatternsAndFold(ops, std::move(patterns), config);
-  }
+    // A failure typically indicates that the pattern application did not
+    // converge.
+    if (failed(result)) {
+      return emitSilenceableFailure(target)
+             << "greedy pattern application failed";
+    }
 
-  // A failure typically indicates that the pattern application did not
-  // converge.
-  if (failed(result)) {
-    return emitSilenceableFailure(target)
-           << "greedy pattern application failed";
-  }
+    if (getApplyCse()) {
+      DominanceInfo domInfo;
+      mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
+                                          &cseChanged);
+    }
+  } while (cseChanged && ++iteration < kNumMaxIterations);
+
+  if (iteration == kNumMaxIterations)
+    return emitDefiniteFailure() << "fixpoint iteration did not converge";
 
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 992f78623a825a..a9a5e43cc06774 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -210,3 +210,24 @@ module {
     }
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalization_and_cse(
+//   CHECK-NOT:   memref.subview
+//   CHECK-NOT:   memref.copy
+func.func @canonicalization_and_cse(%m: memref<5xf32>) {
+  %c2 = arith.constant 2 : index
+  %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, strided<[1], offset: 1>>
+  %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref<?xf32, strided<[1], offset: 1>>
+  memref.copy %s0, %s1 : memref<2xf32, strided<[1], offset: 1>> to memref<?xf32, strided<[1], offset: 1>>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %1 {
+    transform.apply_patterns.canonicalization
+  } {apply_cse} : !transform.any_op
+}


        


More information about the Mlir-commits mailing list