[Mlir-commits] [mlir] 9144fed - [mlir] Add option for a cleanup pattern set to SCF tiling helper (#109554)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 4 11:42:58 PDT 2024


Author: Quinn Dawkins
Date: 2024-10-04T14:42:55-04:00
New Revision: 9144fed31b59089f4e3e5fedf7eb87d2695ef843

URL: https://github.com/llvm/llvm-project/commit/9144fed31b59089f4e3e5fedf7eb87d2695ef843
DIFF: https://github.com/llvm/llvm-project/commit/9144fed31b59089f4e3e5fedf7eb87d2695ef843.diff

LOG: [mlir] Add option for a cleanup pattern set to SCF tiling helper (#109554)

The SCF helper for tiling an operation implementing the TilingInterface
and greedily fusing its producers requires an uninterrupted chain of
operations implementing the tiling interface to succeed. There can be
cases with intermediate ops that do not implement the interface but whose
producers could still be fused if canonicalization/simplification
patterns were allowed to run between fusion steps.

This adds an option to SCFTileAndFuseOptions for a set of cleanup
patterns that is applied, between fusion steps, to the ops that result
from tiling/fusion. Removed and newly inserted `tensor.extract_slice`
ops are tracked so that fusion can continue from them.

See this RFC for more discussion:

https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a997502c34299c..f9036cf96e9a1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -295,18 +295,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let description = [{
     Tiles the operations pointed to by the target handle and fuses their
     producers greedily using the options provided as attributes.
+
+    If `apply_cleanup` is true then slice canonicalization is applied between
+    fusion steps.
   }];
 
   let arguments =
     (ins TransformHandleTypeInterface:$target,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
-         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
+         DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
   let results = (outs TransformHandleTypeInterface:$transformed,
                       Variadic<TransformHandleTypeInterface>:$loops);
 
   let assemblyFormat = [{
     $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
-    attr-dict `:` functional-type(operands, results)
+    (`apply_cleanup` `=` $apply_cleanup^)? attr-dict
+    `:` functional-type(operands, results)
   }];
   let hasVerifier = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 77c812cde71533..9f5f9f3fca97ad 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
 #include <deque>
 
@@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
     fusionControlFn = controlFn;
     return *this;
   }
+
+  /// An optional set of rewrite patterns to apply to the results of tiling
+  /// before fusion. This will track deleted and newly inserted
+  /// `tensor.extract_slice` ops and update the worklist.
+  std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
 };
 
 /// Fuse the producer of the source of `candidateSliceOp` by computing the

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0b9223013a0f1b..8e7621754f76bf 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
   tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions = tilingOptions;
+
+  if (getApplyCleanup()) {
+    MLIRContext *context = rewriter.getContext();
+    RewritePatternSet patterns(context);
+    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
+    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tileAndFuseOptions.cleanupPatterns = std::move(patterns);
+  }
+
   LogicalResult result = applyTilingToAll(
       rewriter, getOperation(), state.getPayloadOps(getTarget()),
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 50cfd29e6bf907..e2feb10b314540 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -24,6 +24,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
   return generatedSlices;
 }
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// SliceTrackingListener
+//===----------------------------------------------------------------------===//
+
+/// This class is a listener for tracking the insertion and removal of
+/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
+/// fusion algorithm to apply cleanup patterns in between fusion steps.
+class SliceTrackingListener : public RewriterBase::Listener {
+public:
+  explicit SliceTrackingListener(
+      std::optional<FrozenRewritePatternSet> patterns);
+  SliceTrackingListener() = default;
+
+  /// Adds the given list of operations to the worklist, and if present, applies
+  /// the list of `patterns` to the newly added operations. This only processes
+  /// the given operations and any newly inserted ones by the pattern set.
+  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
+
+  /// Add to the new operation worklist if it is an extract_slice.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Shared helper for operation removal from the worklist.
+  void removeOp(Operation *op);
+
+  /// Remove the operation from the worklist.
+  void notifyOperationErased(Operation *op) override;
+
+  /// Remove the operation from the worklist.
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
+
+  /// The worklist for this transformation keeps track of the slices to visit
+  /// next for fusion.
+  std::deque<tensor::ExtractSliceOp> worklist;
+
+private:
+  /// Optional pattern set to apply when adding new operations to the worklist.
+  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
+};
+
+SliceTrackingListener::SliceTrackingListener(
+    std::optional<FrozenRewritePatternSet> p) {
+  patterns = std::move(p);
+}
+
+LogicalResult
+SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
+  for (Operation *op : ops) {
+    if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
+      worklist.push_back(slice);
+  }
+
+  if (!patterns)
+    return success();
+
+  GreedyRewriteConfig config;
+  config.listener = this;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+  return applyOpPatternsAndFold(ops, patterns.value(), config);
+}
+
+void SliceTrackingListener::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
+  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
+  if (!slice)
+    return;
+  worklist.push_back(slice);
+}
+
+// Scan the worklist for the given op and remove it if present. The expectation
+// is for the worklist to be small and for removal to be relatively rare.
+void SliceTrackingListener::removeOp(Operation *op) {
+  if (!isa<tensor::ExtractSliceOp>(op))
+    return;
+  auto iter = worklist.begin();
+  while (iter != worklist.end()) {
+    if (*iter == op)
+      break;
+    iter++;
+  }
+  if (iter == worklist.end())
+    return;
+
+  worklist.erase(iter);
+}
+
+void SliceTrackingListener::notifyOperationErased(Operation *op) {
+  removeOp(op);
+}
+
+void SliceTrackingListener::notifyOperationReplaced(Operation *op,
+                                                    ValueRange replacement) {
+  removeOp(op);
+}
+} // namespace
+
 /// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     tensor::ExtractSliceOp candidateSlice;
     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
   };
-  std::deque<WorklistItem> worklist;
-  auto addCandidateSlices = [&worklist, &options,
-                             &loops](ArrayRef<Operation *> candidates) {
-    for (auto candidate : candidates) {
-      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
-      if (!sliceOp || sliceOp.use_empty())
-        continue;
 
-      auto [fusableProducer, destinationInitArg] =
-          getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
-      if (!fusableProducer)
-        continue;
-      std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
-          options.fusionControlFn(sliceOp, fusableProducer,
-                                  destinationInitArg.has_value());
-      if (!controlFnResult)
-        continue;
-      worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
-    }
-  };
+  SliceTrackingListener sliceTracker =
+      SliceTrackingListener(options.cleanupPatterns);
 
-  addCandidateSlices(tilingResult->generatedSlices);
+  if (failed(
+          sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
+    return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
+  }
   OpBuilder::InsertionGuard g(rewriter);
-  while (!worklist.empty()) {
-    // Traverse the slices in BFS fashion.
-    WorklistItem worklistItem = worklist.front();
-    worklist.pop_front();
+  while (!sliceTracker.worklist.empty()) {
+    auto candidateSlice = sliceTracker.worklist.front();
+    sliceTracker.worklist.pop_front();
+
+    auto [fusableProducer, destinationInitArg] =
+        getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
+                                          loops);
+    if (!fusableProducer)
+      continue;
+
+    std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
+        options.fusionControlFn(candidateSlice, fusableProducer,
+                                destinationInitArg.has_value());
+    if (!controlFnResult)
+      continue;
+
+    WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
 
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     if (!fusedResult)
       continue;
 
+    SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
+
     if (worklistItem.controlFnResult.yieldProducerReplacement) {
       // Reconstruct and yield all opResult of fusableProducerOp by default. The
       // caller can specific which one to yield by designating optional argument
@@ -1421,7 +1522,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
             fusableProducerOp, "failed to replacement value for this "
                                "operation from within the tiled loop");
       }
-      addCandidateSlices(newSlices.value());
+      worklistCandidates.append(newSlices.value());
       for (auto [index, result] :
            llvm::enumerate(fusableProducerOp->getResults())) {
         origValToResultNumber[result] = loops.front()->getNumResults() -
@@ -1429,12 +1530,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                         index;
       }
     }
-    addCandidateSlices(fusedResult->generatedSlices);
     if (Operation *tiledAndFusedOp =
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
     }
+
+    if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
+      return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
+    }
   }
 
   DenseMap<Value, Value> replacements;

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 3a023deb1132f3..ac1ca9319d3354 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_through_slice
+func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  //     CHECK: return %[[RES]]
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
+func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  //     CHECK: return %[[RES]]
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
+                             outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
+  %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
+  %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
+  %3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_unrelated_slice
+func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {
+
+  //     CHECK: %[[SLICE1:.+]] = tensor.extract_slice
+  //     CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  //     CHECK: return %[[RES]], %[[SLICE2]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list