[Mlir-commits] [mlir] 97f9198 - [mlir][TilingInterface] NFC Refactor of tile and fuse using `TilingInterface`.

Mahesh Ravishankar llvmlistbot at llvm.org
Wed Sep 28 13:26:00 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-09-28T20:25:33Z
New Revision: 97f919820b075fe49393405bf0ea990cf820ffeb

URL: https://github.com/llvm/llvm-project/commit/97f919820b075fe49393405bf0ea990cf820ffeb
DIFF: https://github.com/llvm/llvm-project/commit/97f919820b075fe49393405bf0ea990cf820ffeb.diff

LOG: [mlir][TilingInterface] NFC Refactor of tile and fuse using `TilingInterface`.

This patch refactors the tiling and tile + fuse implementation using
`TilingInterface`. Primarily, it exposes the functionality as simple
utility functions instead of as a Pattern to allow calling it from a
pattern as it is done in the test today or from within the transform
dialect (in the future). This is a step towards deprecating similar
methods in the Linalg dialect.

- The utility methods do not erase the root operations.
- The return value provides the values to use for replacements.

Differential Revision: https://reviews.llvm.org/D134144

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index a56b6b44e4657..1c374d62425d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -60,38 +60,48 @@ struct SCFTilingOptions {
   }
 };
 
+/// Transformation information returned after tiling.
 struct SCFTilingResult {
+  /// The tiled operation generated.
   Operation *tiledOp;
+  /// The `scf.for` operations that iterate over the tiles.
   SmallVector<scf::ForOp> loops;
+  /// Values to use as replacements for the untiled op. Is the same size as the
+  /// number of results of the untiled op.
+  SmallVector<Value> replacements;
 };
 
-/// Pattern to tile an op that implements the `TilingInterface` using
+/// Method to tile an op that implements the `TilingInterface` using
 /// `scf.for` for iterating over the tiles.
-struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
-  /// Construct a generic pattern applied to all TilingInterface ops.
-  TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
-                    PatternBenefit benefit = 1);
-
-  /// Construct a generic pattern applied to `opName`.
-  TileUsingSCFForOp(StringRef opName, MLIRContext *context,
-                    SCFTilingOptions options, PatternBenefit benefit = 1);
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<SCFTilingResult>
-  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(TilingInterface op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
+FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
+                                             TilingInterface op,
+                                             SCFTilingOptions options);
+
+/// Options used to control tile + fuse.
+struct SCFTileAndFuseOptions {
+  /// The tiling options used to control the tiling of the consumer.
+  SCFTilingOptions tilingOptions;
+  SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
+    tilingOptions = options;
+    return *this;
   }
+};
 
-private:
-  /// Options to control tiling;
-  SCFTilingOptions options;
+/// Transformation information returned after tile and fuse.
+struct SCFTileAndFuseResult {
+  /// List of untiled operations that were fused with the tiled consumer.
+  llvm::SetVector<Operation *> fusedProducers;
+  /// List of tiled and fused operations generated. The first one in this list
+  /// is guaranteed to be the tiled operation generated during tiling of the
+  /// consumer operation.
+  llvm::SetVector<Operation *> tiledAndFusedOps;
+  /// The `scf.for` operations that iterate over the tiles.
+  SmallVector<scf::ForOp> loops;
+  /// The replacement values to use for the tiled and fused operations.
+  llvm::DenseMap<Value, Value> replacements;
 };
 
-/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
+/// Method to tile and fuse a sequence of operations, by tiling the consumer
 /// and fusing its producers. Note that this assumes that it is valid to
 /// tile+fuse the producer into the innermost tiled loop. Its up to the caller
 /// to ensure that the tile sizes provided make this fusion valid.
@@ -99,64 +109,32 @@ struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
 /// For example, for the following sequence
 ///
 /// ```mlir
-/// %0 = linalg.fill ...
-/// %1 = linalg.matmul ... outs(%0 : ...) ...
+/// %0 =
+/// %1 = linalg.fill ... outs(%0 : ... )
+/// %2 = linalg.matmul ... outs(%1 : ...) ...
 /// ```
 ///
 /// it is legal to fuse the fill with the matmul only if the matmul is tiled
 /// along the parallel dimensions and not the reduction dimension, i.e. the tile
-/// size for the reduction dimension should be 0.
-struct SCFTileAndFuseResult {
-  SmallVector<Operation *> tiledAndFusedOps;
-  SmallVector<scf::ForOp> loops;
-};
-struct TileConsumerAndFuseProducersUsingSCFForOp
-    : public OpInterfaceRewritePattern<TilingInterface> {
-
-  /// Construct a generic pattern applied to all TilingInterface ops.
-  TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
-                                            SCFTilingOptions options,
-                                            PatternBenefit benefit = 1);
-
-  /// Construct a generic pattern applied to `opName`.
-  TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
-                                            MLIRContext *context,
-                                            SCFTilingOptions options,
-                                            PatternBenefit benefit = 1);
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<SCFTileAndFuseResult>
-  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(TilingInterface op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-
-private:
-  /// This pattern uses the tiling pattern. Instead of using inheritance, use
-  /// the patterns as private object that is instantiated at the same time as
-  /// this pattern.
-  TileUsingSCFForOp tilingPattern;
-};
-
-/// Pattern to lower operations that implement the `TilingInterface` to
-/// loops/scalar IR using `scf.for`.
-struct LowerToLoopsUsingSCFForOp
-    : public OpInterfaceRewritePattern<TilingInterface> {
-  using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<SmallVector<scf::ForOp>>
-  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(TilingInterface op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-};
+/// size for the reduction dimension should be 0. The resulting fused
+/// transformation is
+///
+/// ```mlir
+/// %1 = scf.for ... iter_args(%arg0 = %0)
+///   %2 = tensor.extract_slice %arg0
+///   %3 = linalg.fill .. outs(%2 : ... )
+///   %4 = linalg.matmul .. outs(%3 : ...)
+/// }
+/// ```
+FailureOr<SCFTileAndFuseResult>
+tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter,
+                                                 TilingInterface consumer,
+                                                 SCFTileAndFuseOptions options);
+
+/// Method to lower an `op` that implements the `TilingInterface` to
+/// loops/scalars.
+FailureOr<SmallVector<scf::ForOp>>
+lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 
 } // namespace scf
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0c6ba3d195da5..5342be5cfdc65 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -87,7 +87,7 @@ static bool isPermutation(ArrayRef<unsigned> interchange) {
 }
 
 //===----------------------------------------------------------------------===//
-// TileUsingSCFForOp pattern implementation.
+// tileUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
 
 // Check if `stride` evenly divides the trip count `size - offset`.
@@ -167,7 +167,65 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
   return loops;
 }
 
-/// If the tiled operation is in destination passing style, update the
+/// For each value to be yielded (`yieldedValues`) from within a loop nest
+/// `loops`, construct the destructive update pattern that inserts the yielded
+/// value into the corresponding destination tensor provided by `initValues`,
+/// at the offsets/sizes in `tileOffsetsList`/`tileSizesList`. For example,
+///
+/// ```mlir
+/// scf.for %iv0 = ... {
+///   %0 = tiled_op
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %init) {
+///   %1 = tensor.extract_slice %arg
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
+static FailureOr<SmallVector<Value>>
+yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
+                 ValueRange yieldedValues,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops) {
+  NewYieldValueFn yieldValueFn =
+      [&](OpBuilder &b, Location loc,
+          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+    SmallVector<Value> inserts;
+    for (auto yieldedValue : llvm::enumerate(yieldedValues)) {
+      ArrayRef<OpFoldResult> tileOffsets =
+          tileOffsetsList[yieldedValue.index()];
+      ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
+      SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
+                                            b.getIndexAttr(1));
+      Value insert = b.create<tensor::InsertSliceOp>(
+          loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
+          tileOffsets, tileSizes, tileStrides);
+      inserts.push_back(insert);
+    }
+    return inserts;
+  };
+
+  SmallVector<scf::ForOp> newLoops =
+      replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
+                                   /*replaceIterOperandsUsesInLoop =*/false);
+  for (const auto &loop : llvm::enumerate(loops)) {
+    rewriter.eraseOp(loop.value());
+    loops[loop.index()] = newLoops[loop.index()];
+  }
+  return llvm::to_vector(llvm::map_range(
+      loops.front().getResults().take_back(yieldedValues.size()),
+      [](OpResult r) -> Value { return r; }));
+}
+
+/// If the tiled operation is destination passing style, update the
 /// slice of the destination used (which refers to the untiled destination)
 /// to use the corresponding region argument of the innermost loop.
 ///
@@ -191,8 +249,6 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
 ///   scf.yield %3
 /// }
 /// ```
-/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is
-/// available generally.
 static void
 updateDestinationOperandsForTiledOp(OpBuilder &builder,
                                     ValueRange tiledOpDestinationValues,
@@ -205,22 +261,11 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
   }
 }
 
-scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
-                                          scf::SCFTilingOptions options,
-                                          PatternBenefit benefit)
-    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
-      options(std::move(options)) {}
-
-scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
-                                          MLIRContext *context,
-                                          scf::SCFTilingOptions options,
-                                          PatternBenefit benefit)
-    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
-      options(std::move(options)) {}
-
+/// Implementation of tiling transformation of `op` that implements the
+/// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
-scf::TileUsingSCFForOp::returningMatchAndRewrite(
-    TilingInterface op, PatternRewriter &rewriter) const {
+mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
+                             scf::SCFTilingOptions options) {
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
@@ -282,132 +327,86 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
       offsets = applyPermutationToVector(offsets, inversePermutation);
       sizes = applyPermutationToVector(sizes, inversePermutation);
     }
+  }
 
-    LLVM_DEBUG({
-      if (!tilingResult.loops.empty()) {
-        llvm::errs() << "LoopNest shell :\n";
-        tilingResult.loops.front().dump();
-        llvm::errs() << "\n";
-      }
-    });
-
-    // 4. Generate the tiled implementation within the inner most loop.
-    if (!tilingResult.loops.empty())
-      rewriter.setInsertionPoint(
-          tilingResult.loops.back().getBody()->getTerminator());
-    SmallVector<Operation *> tiledImplementation =
-        op.getTiledImplementation(rewriter, offsets, sizes);
-    if (tiledImplementation.size() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected tiled implementation to return a single op");
+  LLVM_DEBUG({
+    if (!tilingResult.loops.empty()) {
+      llvm::dbgs() << "LoopNest shell :\n";
+      tilingResult.loops.front().dump();
+      llvm::dbgs() << "\n";
     }
-    tilingResult.tiledOp = tiledImplementation[0];
-
-    LLVM_DEBUG({
-      if (!tilingResult.loops.empty()) {
-        llvm::errs() << "After tiled implementation :\n";
-        tilingResult.loops.front().dump();
-        llvm::errs() << "\n";
-      }
-    });
+  });
+
+  // 4. Generate the tiled implementation within the inner most loop.
+  if (!tilingResult.loops.empty())
+    rewriter.setInsertionPoint(
+        tilingResult.loops.back().getBody()->getTerminator());
+  SmallVector<Operation *> tiledImplementation =
+      op.getTiledImplementation(rewriter, offsets, sizes);
+  if (tiledImplementation.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        op, "expected tiled implementation to return a single op");
   }
-
+  tilingResult.tiledOp = tiledImplementation[0];
   if (op->getNumResults() == 0) {
-    rewriter.eraseOp(op);
+    // nothing more to do.
     return tilingResult;
   }
 
-  // 5. If the original operations has results, modify the loop nest to yield
-  // the replacement values.
+  // If loops are empty, the tiled op is used as the replacement for the untiled
+  // op.
   if (tilingResult.loops.empty()) {
-    // 5a. If there were no loops, the tiled implementation results are the
-    // replacements.
-    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+    tilingResult.replacements = llvm::to_vector(
+        llvm::map_range(tiledImplementation[0]->getResults(),
+                        [](OpResult result) -> Value { return result; }));
     return tilingResult;
   }
 
-  // 6. Yield the results of the tiled operation from the loop nest as
-  //    replacements for the original untiled ops.
-  if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) {
-    return rewriter.notifyMatchFailure(
-        tilingResult.tiledOp,
-        "expected tiled op to have as many results as the untiled operation");
+  // 5. Yield all the results of the tiled operation. The surrounding loop
+  //    nest is modified to insert a destructive update pattern to yield
+  //    from the loop nest values to replace the untiled op with.
+  unsigned numResults = op->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
+      resultSizesList(numResults);
+  for (auto result : llvm::enumerate(op->getResults())) {
+    if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
+                                        sizes,
+                                        resultOffsetsList[result.index()],
+                                        resultSizesList[result.index()]))) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get slice of result produced");
+    }
   }
 
-  // `scf.for` with tensor semantics requires the loop nest to yield the
-  // replacement values using destructive updates. Use the `TilingInterface`
-  // to get the position of the result tiles and use that to generate the
-  // destructive update pattern, i.e.,
-  //
-  // ```mlir
-  // scf.for %iv0 = ... {
-  //   %0 = tiled_op
-  // }
-  // ```
-  //
-  // is transformed to
-  //
-  // ```mlir
-  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
-  //   %0 = tiled_op
-  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
-  //   scf.yield %1
-  // }
-  // ```
-  NewYieldValueFn yieldValueFn =
-      [&](OpBuilder &b, Location loc,
-          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
-    SmallVector<Value> yieldedValues;
-    Attribute one = b.getIndexAttr(1);
-    for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
-      SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
-      if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
-                                          resultTileOffsets,
-                                          resultTileSizes))) {
-        op.emitOpError("unable to get position of result ")
-            << resultNum << " of the tiled implementation";
-        return {};
-      }
-      SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
-                                                  one);
-      Value yieldedValue = b.create<tensor::InsertSliceOp>(
-          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
-          newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
-          resultTileStrides);
-      yieldedValues.push_back(yieldedValue);
-    }
-    return yieldedValues;
-  };
-  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
-      rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
-      yieldValueFn, /*replaceIterOperandsUsesInLoops =*/false);
-  for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
-    rewriter.eraseOp(loop.value());
-    tilingResult.loops[loop.index()] = newLoops[loop.index()];
+  FailureOr<SmallVector<Value>> replacementOr =
+      yieldTiledValues(rewriter, op.getDestinationOperands(rewriter),
+                       tilingResult.tiledOp->getResults(), resultOffsetsList,
+                       resultSizesList, tilingResult.loops);
+  if (failed(replacementOr))
+    return rewriter.notifyMatchFailure(op, "failed to yield replacement");
+  if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) {
+    auto innerMostLoop = tilingResult.loops.back();
+    updateDestinationOperandsForTiledOp(
+        rewriter, tiledInterfaceOp.getDestinationOperands(rewriter),
+        innerMostLoop.getRegionIterArgs());
   }
-  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+
+  tilingResult.replacements = replacementOr.value();
+
+  LLVM_DEBUG({
+    if (!tilingResult.loops.empty()) {
+      llvm::dbgs() << "After tiled implementation :\n";
+      tilingResult.loops.front().dump();
+      llvm::dbgs() << "\n";
+    }
+  });
   return tilingResult;
 }
 
 //===----------------------------------------------------------------------===//
-// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
+// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
 
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
-    TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
-                                              scf::SCFTilingOptions options,
-                                              PatternBenefit benefit)
-    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
-      tilingPattern(context, std::move(options)) {}
-
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
-    TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
-                                              MLIRContext *context,
-                                              scf::SCFTilingOptions options,
-                                              PatternBenefit benefit)
-    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
-      tilingPattern(context, std::move(options)) {}
-
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -430,28 +429,41 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
+/// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
-scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
-    TilingInterface op, PatternRewriter &rewriter) const {
+mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+    RewriterBase &rewriter, TilingInterface consumer,
+    scf::SCFTileAndFuseOptions options) {
   // This transformation is only valid for ops that return values (i.e. not
   // valid to use with operations that have memref operands).
-  if (!op->getNumResults()) {
+  if (!consumer->getNumResults()) {
     return rewriter.notifyMatchFailure(
-        op, "invalid pattern for op with no results");
+        consumer, "invalid pattern for op with no results");
   }
 
   // 1. First tile the consumer.
-  SCFTileAndFuseResult tileAndFuseResult;
+  scf::SCFTileAndFuseResult tileAndFuseResult;
+  llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber;
   {
-    FailureOr<SCFTilingResult> tilingResult =
-        tilingPattern.returningMatchAndRewrite(op, rewriter);
-    if (failed(tilingResult)) {
-      return failure();
-    }
-    tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+    if (failed(tilingResult))
+      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+    tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
     tileAndFuseResult.loops = std::move(tilingResult->loops);
+    for (auto result : llvm::enumerate(
+             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
+      tileAndFuseResult.replacements[std::get<0>(result.value())] =
+          std::get<1>(result.value());
+      yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+          result.index())] = result.index();
+    }
   }
 
+  // If there are no loops generated, fusion is immaterial.
+  if (tileAndFuseResult.loops.empty())
+    return tileAndFuseResult;
+
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
   //    `tensor.extract_slice` operations with source being the operands of the
@@ -495,7 +507,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     //     values produced by operations that implement the `TilingInterface`.
     //     Add these operations to the worklist.
     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
-    tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
+    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
     addCandidateSlices(fusedProducer, candidates);
 
     // 2e. If the slice is for a destination operand, for example,
@@ -577,20 +589,19 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// LowerToLoopsUsingSCFForOp
+// lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
 
 FailureOr<SmallVector<scf::ForOp>>
-scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
-    TilingInterface op, PatternRewriter &rewriter) const {
-  SmallVector<Range> domain = op.getIterationDomain(rewriter);
-
+mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
+                                     TilingInterface op) {
   // TODO: Handle cases where the op has results if needed.
   if (op->getNumResults() > 0) {
     return rewriter.notifyMatchFailure(
         op, "unable to lower to loops operations with return values");
   }
 
+  SmallVector<Range> domain = op.getIterationDomain(rewriter);
   SmallVector<Value> ivs;
   SmallVector<scf::ForOp> loops;
   Location loc = op.getLoc();
@@ -610,6 +621,5 @@ scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
     return failure();
   }
-  rewriter.eraseOp(op);
   return loops;
 }

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index b20cd640c3881..edb7ba3729173 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -36,38 +36,46 @@ namespace {
 /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
 /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
 /// using a `filter` to avoid recursive application.
-struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
-  TestTileUsingSCFForOpWithFilter(MLIRContext *context,
-                                  scf::SCFTilingOptions options,
-                                  linalg::LinalgTransformationFilter filter =
-                                      linalg::LinalgTransformationFilter(),
-                                  PatternBenefit benefit = 1)
-      : scf::TileUsingSCFForOp(context, std::move(options), benefit),
-        filter(std::move(filter)) {}
+struct TestTileUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
+                        linalg::LinalgTransformationFilter filter =
+                            linalg::LinalgTransformationFilter(),
+                        PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
 
   /// Construct a generic pattern applied to `opName`.
-  TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
-                                  scf::SCFTilingOptions options,
-                                  linalg::LinalgTransformationFilter filter =
-                                      linalg::LinalgTransformationFilter(),
-                                  PatternBenefit benefit = 1)
-      : scf::TileUsingSCFForOp(context, std::move(options), benefit),
-        filter(std::move(filter)) {}
+  TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
+                        scf::SCFTilingOptions options,
+                        linalg::LinalgTransformationFilter filter =
+                            linalg::LinalgTransformationFilter(),
+                        PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
 
   LogicalResult matchAndRewrite(TilingInterface op,
                                 PatternRewriter &rewriter) const override {
     if (failed(filter.checkAndNotify(rewriter, op)))
       return failure();
 
-    auto tilingResult = returningMatchAndRewrite(op, rewriter);
-    if (failed(tilingResult)) {
-      return failure();
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCFForOp(rewriter, op, options);
+    if (failed(tilingResult))
+      return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+    if (op->getNumResults()) {
+      rewriter.replaceOp(op, tilingResult->replacements);
+    } else {
+      rewriter.eraseOp(op);
     }
+
     filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
     return success();
   }
 
 private:
+  scf::SCFTilingOptions options;
   linalg::LinalgTransformationFilter filter;
 };
 
@@ -75,45 +83,74 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
 /// ops for iterating over the tiles) while using a `filter` to avoid recursive
 /// application.
-struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
-    : public scf::TileConsumerAndFuseProducersUsingSCFForOp {
-  TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
-      MLIRContext *context, scf::SCFTilingOptions options,
+struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
+      MLIRContext *context, scf::SCFTileAndFuseOptions options,
       linalg::LinalgTransformationFilter filter =
           linalg::LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : scf::TileConsumerAndFuseProducersUsingSCFForOp(
-            context, std::move(options), benefit),
-        filter(std::move(filter)) {}
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
 
   /// Construct a generic pattern applied to `opName`.
-  TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
-      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+  TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
+      StringRef opName, MLIRContext *context,
+      scf::SCFTileAndFuseOptions options,
       linalg::LinalgTransformationFilter filter =
           linalg::LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : scf::TileConsumerAndFuseProducersUsingSCFForOp(
-            context, std::move(options), benefit),
-        filter(std::move(filter)) {}
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
 
   LogicalResult matchAndRewrite(TilingInterface op,
                                 PatternRewriter &rewriter) const override {
     if (failed(filter.checkAndNotify(rewriter, op)))
       return failure();
 
-    auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
+    FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+        scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op,
+                                                              options);
     if (failed(tileAndFuseResult)) {
       return failure();
     }
+    // Replace the tiled op with replacements.
+    SmallVector<Value> replacements(op->getNumResults());
+    for (auto result : llvm::enumerate(op->getResults())) {
+      replacements[result.index()] =
+          tileAndFuseResult->replacements.lookup(result.value());
+    }
+    rewriter.replaceOp(op, replacements);
+
     filter.replaceLinalgTransformationFilter(
         rewriter, tileAndFuseResult->tiledAndFusedOps.front());
     return success();
   }
 
 private:
+  scf::SCFTileAndFuseOptions options;
   linalg::LinalgTransformationFilter filter;
 };
 
+/// Pattern to lower operations that implement the `TilingInterface` to
+/// loops/scalar IR using `scf.for`.
+struct LowerToLoopsUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
+
+  /// `matchAndRewrite` implementation that lowers `op` to loops using
+  /// `scf::lowerToLoopsUsingSCFForOp` and erases `op` on success.
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<SmallVector<scf::ForOp>> loops =
+        scf::lowerToLoopsUsingSCFForOp(rewriter, op);
+    if (failed(loops))
+      return rewriter.notifyMatchFailure(op, "failed to lower to loops");
+    rewriter.eraseOp(op);
+    return loops;
+  }
+};
+
 /// Test pass for testing the use of `TilingInterface`.
 struct TestTilingInterfacePass
     : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
@@ -158,72 +195,78 @@ struct TestTilingInterfacePass
 };
 } // namespace
 
-template <class Pattern>
-static void
-addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns,
-                    StringRef filterName, ArrayRef<int64_t> tileSizes,
-                    ArrayRef<unsigned> interchange = {}) {
+static void addPatternForTiling(MLIRContext *context,
+                                RewritePatternSet &patterns,
+                                StringRef filterName,
+                                ArrayRef<int64_t> tileSizes,
+                                ArrayRef<unsigned> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
   linalg::LinalgTransformationFilter filter(
       StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
-  patterns.add<Pattern>(context, tilingOptions, filter);
+  patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
+}
+
+static void addPatternForTileAndFuse(MLIRContext *context,
+                                     RewritePatternSet &patterns,
+                                     StringRef filterName,
+                                     ArrayRef<int64_t> tileSizes,
+                                     ArrayRef<unsigned> interchange = {}) {
+  scf::SCFTileAndFuseOptions tileAndFuseOptions;
+  tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
+      interchange);
+  linalg::LinalgTransformationFilter filter(
+      StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+  patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
+      context, tileAndFuseOptions, filter);
 }
 
 void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
                                               RewritePatternSet &patterns) {
   if (testTiling) {
     // 1. Tiling M and N dims of `linalg.matmul` on tensors.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "simple_gemm", {10, 20});
+    addPatternForTiling(context, patterns, "simple_gemm", {10, 20});
     // 2. Tiling M, N and K of `linalg.matmul` on buffers.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "simple_gemm_memref", {10, 20, 30});
+    addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30});
     // 3. Tiling 3D parallel generic op which implements a transpose
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "parallel_generic_transpose", {10, 0, 20});
+    addPatternForTiling(context, patterns, "parallel_generic_transpose",
+                        {10, 0, 20});
     // 4. Tiling 2D conv op.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30});
+    addPatternForTiling(context, patterns, "simple_conv",
+                        {0, 0, 0, 0, 10, 20, 30});
     // 5. Tiling a simple op with `linalg.index` inside.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "indexed_semantics", {10, 20});
+    addPatternForTiling(context, patterns, "indexed_semantics", {10, 20});
     // 6. Tiling + interchange of an operation
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
+    addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30},
+                        {1, 2, 0});
     // 7. Tiling for 2D pad tensor operations.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "pad_2dtiling", {2, 3});
+    addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3});
     // 8. Tiling inner dimension of 2d pad tensor operations.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "pad_inner_tiling", {0, 3});
+    addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3});
     // 9. Tiling inner dimension of 2d pad tensor operations.
-    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, patterns, "pad_outer_tiling", {2, 3});
+    addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3});
 
     return;
   }
   if (testTileConsumerAndFuseProducer) {
-    // 1. Tile and fuse of gemm with bias-add operation.
-    addPatternForTiling<
-        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, patterns, "fusion", {10, 20});
-    addPatternForTiling<
-        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, patterns, "gemm_fusion", {10});
-    addPatternForTiling<
-        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
-    addPatternForTiling<
-        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, patterns, "gemm_plus_gemm_fusion", {10, 20});
-    addPatternForTiling<
-        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, patterns, "gemm_sequence_fusion", {10});
+    // 1. Tile and fuse of gemm with fill producer and bias-add consumer.
+    addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});
+    // 2. Tile and fuse sequence of GEMMs, by fusing only along M.
+    addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10});
+    // 3. Tile and fuse gemm with consumer + interchange of tiled loops.
+    addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion",
+                             {10, 20}, {1, 0});
+    // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant
+    // computations.
+    addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion",
+                             {10, 20});
+    // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M
+    // dimension.
+    addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10});
     return;
   }
   if (testLoweringToScalar) {
-    patterns.add<scf::LowerToLoopsUsingSCFForOp>(context);
+    patterns.add<LowerToLoopsUsingSCFForOp>(context);
   }
 }
 


        


More information about the Mlir-commits mailing list