[Mlir-commits] [mlir] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (PR #107882)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 09:06:45 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/107882

The current implementation of `scf::tileConsumerAndFuseProducersUsingSCF` looks at the operands of tiled and tiled+fused operations to see whether they are produced by `extract_slice` operations, and uses those to populate the worklist used to continue fusion. This implicit assumption does not always hold. Instead, make the implementations of `getTiledImplementation` return the slices to use to continue fusion.
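As a rough illustration of the idea (not the MLIR code itself), the sketch below drives a breadth-first worklist from the slices each tiling step reports, rather than from the operands of the tiled op. `Slice`, `TilingResult`, `Graph`, `tileOp` and `tileAndFuse` are hypothetical stand-ins that use plain strings in place of operations; only the `generatedSlices` field name is taken from the patch.

```cpp
#include <cassert>
#include <deque>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not MLIR types.
struct Slice {
  std::string source; // name of the op whose result is being sliced
};

struct TilingResult {
  std::string tiledOp;                // name of the tiled op
  std::vector<Slice> generatedSlices; // slices the implementation created
};

// Toy producer graph: maps an op to the ops that produce its operands.
// Assumed to be acyclic, otherwise the loop below would not terminate.
using Graph = std::map<std::string, std::vector<std::string>>;

// Toy analogue of `getTiledImplementation`: "tiles" `op` and reports one
// generated slice per operand instead of leaving the caller to guess.
inline TilingResult tileOp(const Graph &graph, const std::string &op) {
  TilingResult result{op + ".tiled", {}};
  auto it = graph.find(op);
  if (it != graph.end())
    for (const std::string &producer : it->second)
      result.generatedSlices.push_back(Slice{producer});
  return result;
}

// Tiles `consumer`, then fuses producers breadth-first. The worklist is fed
// only from `generatedSlices`; operands are never inspected.
inline std::vector<std::string> tileAndFuse(const Graph &graph,
                                            const std::string &consumer) {
  assert(graph.count(consumer) && "consumer must be present in the graph");
  std::vector<std::string> fused;
  std::deque<Slice> worklist;
  TilingResult root = tileOp(graph, consumer);
  worklist.insert(worklist.end(), root.generatedSlices.begin(),
                  root.generatedSlices.end());
  while (!worklist.empty()) {
    Slice slice = worklist.front();
    worklist.pop_front();
    TilingResult producer = tileOp(graph, slice.source);
    fused.push_back(producer.tiledOp);
    worklist.insert(worklist.end(), producer.generatedSlices.begin(),
                    producer.generatedSlices.end());
  }
  return fused;
}
```

In this toy model an implementation that reports no slices simply stops fusion along that path, which mirrors why out-of-tree implementations need to start returning their slices.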

This is a breaking change:

- To continue to get the same behavior from `scf::tileConsumerAndFuseProducersUsingSCF`, change all out-of-tree implementations of `TilingInterface::getTiledImplementation` to return the slices to continue fusion on. All in-tree implementations have been adapted to this.
- This change also simplifies the `ControlFn` in `scf::SCFTileAndFuseOptions`. It now returns a `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object that should be `std::nullopt` if fusion is not to be performed.
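A minimal sketch of what a migrated control function might look like, assuming the shape described above. `ControlFnResult` here is a local stand-in mirroring the struct in the patch, and `ControlFnTy` is simplified to take a producer name instead of the real `(tensor::ExtractSliceOp, OpResult, bool)` arguments; `exampleControlFn` and its policy are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>

// Local stand-in mirroring the shape of
// `scf::SCFTileAndFuseOptions::ControlFnResult` from the patch.
struct ControlFnResult {
  bool yieldProducerReplacement = false;
};

// Simplified stand-in for `ControlFnTy`. The real callback receives the
// candidate slice op, the original producer result, and the destination flag.
using ControlFnTy = std::function<std::optional<ControlFnResult>(
    const std::string &producerName, bool isDestinationOperand)>;

// Hypothetical policy: never fuse through destination operands, and yield a
// replacement only for producers named "matmul".
inline std::optional<ControlFnResult>
exampleControlFn(const std::string &producerName, bool isDestinationOperand) {
  if (isDestinationOperand)
    return std::nullopt; // was: first tuple element set to `false`
  ControlFnResult result;
  result.yieldProducerReplacement = (producerName == "matmul");
  return result; // was: std::make_tuple(true, <yield flag>)
}
```

Compared with the previous `std::tuple<bool, bool>` return type, "do not fuse" is now a distinct `std::nullopt` state, so the yield flag only exists when fusion is actually going to happen.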

From 850784b85d6da2e26764b648f62537815b2f37c5 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sat, 7 Sep 2024 23:31:58 -0700
Subject: [PATCH] [mlir][TilingInterface] Avoid looking at operands for getting
 slices to continue tile + fuse.

The current implementation of
`scf::tileConsumerAndFuseProducersUsingSCF` looks at the operands of
tiled and tiled+fused operations to see whether they are produced by
`extract_slice` operations, and uses those to populate the worklist
used to continue fusion. This implicit assumption does not always
hold. Instead, make the implementations of `getTiledImplementation`
return the slices to use to continue fusion.

This is a breaking change:

- To continue to get the same behavior from
  `scf::tileConsumerAndFuseProducersUsingSCF`, change all out-of-tree
  implementations of `TilingInterface::getTiledImplementation` to
  return the slices to continue fusion on. All in-tree implementations
  have been adapted to this.
- This change also simplifies the `ControlFn` in
  `scf::SCFTileAndFuseOptions`. It now returns a
  `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object
  that should be `std::nullopt` if fusion is not to be performed.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 +--
 .../SCF/Transforms/TileUsingInterface.h       | 28 ++++---
 .../include/mlir/Interfaces/TilingInterface.h |  7 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 81 ++++++++++++-------
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 22 +++--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 20 ++---
 .../SCF/Transforms/TileUsingInterface.cpp     | 81 ++++++++++---------
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   | 71 ++++++++++------
 .../TestTilingInterfaceTransformOps.cpp       | 12 +--
 9 files changed, 205 insertions(+), 128 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 65a1a8b42e1495..f1df49ce3eaa36 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
 /// controls whether to omit the partial/boundary tile condition check in
 /// cases where we statically know that it is unnecessary.
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
-                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
-                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
-                     ArrayRef<OpFoldResult> subShapeSizes,
-                     bool omitPartialTileCheck);
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                          ArrayRef<OpFoldResult> lbs,
+                          ArrayRef<OpFoldResult> ubs,
+                          ArrayRef<OpFoldResult> subShapeSizes,
+                          bool omitPartialTileCheck);
 
 /// Creates extract_slice/subview ops for all `valuesToTile` of the given
 /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..ffd36878686751 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
+  /// Slices generated after tiling that can be used for fusing with the tiled
+  /// producer.
+  SmallVector<Operation *> generatedSlices;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
   /// 2) the producer value that is to be fused
   /// 3) a boolean value set to `true` if the fusion is from
   ///    a destination operand.
-  /// It retuns two booleans
-  /// - returns `true` if the fusion should be done through the candidate slice
-  /// - returns `true` if a replacement for the fused producer needs to be
-  ///   yielded from within the tiled loop. Note that it is valid to return
-  ///   `true` only if the slice fused is disjoint across all iterations of the
-  ///   tiled loop. It is up to the caller to ensure that this is true for the
-  ///   fused producers.
-  using ControlFnTy = std::function<std::tuple<bool, bool>(
+  /// The control function returns a `std::optional<ControlFnResult>`.
+  /// If the return value is `std::nullopt`, that implies no fusion
+  /// is to be performed along that slice.
+  struct ControlFnResult {
+    /// Set to true if the loop nest has to return a replacement value
+    /// for the fused producer.
+    bool yieldProducerReplacement = false;
+  };
+  using ControlFnTy = std::function<std::optional<ControlFnResult>(
       tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
       bool isDestinationOperand)>;
-  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
-    return std::make_tuple(true, false);
+  /// The default control function implements greedy fusion without yielding
+  /// a replacement for any of the fused results.
+  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
+                                   bool) -> std::optional<ControlFnResult> {
+    return ControlFnResult{};
   };
   SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
     fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> generatedSlices;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 2f51496d1b110a..b33aa1489c3116 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {
 
 /// Container for result values of tiling.
 /// - `tiledOps` contains operations created by the tiling implementation that
-/// are returned to the caller for further transformations.
+///   are returned to the caller for further transformations.
 /// - `tiledValues` contains the tiled value corresponding to the result of the
-/// untiled operation.
+///   untiled operation.
+/// - `generatedSlices` contains the list of slices that are generated during
+///   tiling. These slices can be used for fusing producers.
 struct TilingResult {
   SmallVector<Operation *> tiledOps;
   SmallVector<Value> tiledValues;
+  SmallVector<Operation *> generatedSlices;
 };
 
 /// Container for the result of merge operation of tiling.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..d423c94487d16c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -66,20 +66,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
 
 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
 /// `source`.
-static Value getSlice(OpBuilder &b, Location loc, Value source,
-                      ArrayRef<OpFoldResult> offsets,
-                      ArrayRef<OpFoldResult> sizes,
-                      ArrayRef<OpFoldResult> strides) {
-  return TypeSwitch<Type, Value>(source.getType())
-      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+                           ArrayRef<OpFoldResult> offsets,
+                           ArrayRef<OpFoldResult> sizes,
+                           ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Operation *>(source.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
         return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                 strides);
       })
-      .Case<MemRefType>([&](MemRefType type) -> Value {
+      .Case<MemRefType>([&](MemRefType type) -> Operation * {
         return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                            strides);
       })
-      .Default([&](Type t) { return nullptr; });
+      .Default([&](Type t) -> Operation * { return nullptr; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -2603,10 +2603,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   auto oneAttr = builder.getI64IntegerAttr(1);
   SmallVector<OpFoldResult> strides(rank, oneAttr);
   SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+  Operation *inputSlice =
+      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+  if (!inputSlice) {
+    return emitOpError("failed to compute input slice");
+  }
+  tiledOperands.emplace_back(inputSlice->getResult(0));
+  Operation *outputSlice =
+      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+  if (!outputSlice) {
+    return emitOpError("failed to compute output slice");
+  }
+  tiledOperands.emplace_back(outputSlice->getResult(0));
 
   SmallVector<Type, 4> resultTypes;
   if (hasPureTensorSemantics())
@@ -2614,7 +2622,9 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{{tiledOp},
+                      SmallVector<Value>(tiledOp->getResults()),
+                      {inputSlice, outputSlice}};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2961,8 +2971,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
   int64_t filterRank = getFilterOperandRank();
   SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
   Location loc = getLoc();
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+  tiledOperands.emplace_back(filterSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -2971,15 +2982,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
@@ -3128,8 +3143,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
   SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+  tiledOperands.emplace_back(inputSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3138,15 +3154,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
@@ -3290,8 +3310,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
                      sizes[getValueFDim()]});
   int64_t valueRank = getValueOperandRank();
   SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+  tiledOperands.emplace_back(valueSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3300,15 +3321,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> strides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, strides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, strides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fbff91a94219cc..ca05467534a5b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,8 +120,12 @@ struct LinalgOpTilingInterface
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     SmallVector<Value> valuesToTile = linalgOp->getOperands();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+    SmallVector<Value> tiledOperands = makeTiledShapes(
         b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+        llvm::make_filter_range(
+            tiledOperands, [](Value v) -> bool { return v.getDefiningOp(); }),
+        [](Value v) -> Operation * { return v.getDefiningOp(); });
 
     SmallVector<Type> resultTensorTypes =
         getTensorOutputTypes(linalgOp, tiledOperands);
@@ -129,7 +133,8 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+    return TilingResult{
+        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
   }
 
   /// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +265,8 @@ struct LinalgOpTilingInterface
 
     return TilingResult{
         tilingResult->tiledOps,
-        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+        tilingResult->generatedSlices};
   }
 
   /// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +412,12 @@ struct LinalgOpPartialReductionInterface
     }
 
     // Step 2a: Extract a slice of the input operands.
-    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+    SmallVector<Value> tiledInputs = makeTiledShapes(
         b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+        llvm::make_filter_range(
+            tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
+        [](Value v) -> Operation * { return v.getDefiningOp(); });
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
@@ -424,6 +434,7 @@ struct LinalgOpPartialReductionInterface
       auto extractSlice = b.create<tensor::ExtractSliceOp>(
           loc, valueToTile, initOffset, initSizes, initStride);
       tiledInits.push_back(extractSlice);
+      generatedSlices.push_back(extractSlice);
     }
 
     // Update the indexing maps.
@@ -453,7 +464,8 @@ struct LinalgOpPartialReductionInterface
     return TilingResult{
         {genericOp.getOperation()},
         llvm::map_to_vector(genericOp->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+                            [](OpResult r) -> Value { return r; }),
+        generatedSlices};
   }
 
   FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index fa0598dd96885c..6a3f2fc5fbc496 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
 
-static Value materializeTiledShape(OpBuilder &builder, Location loc,
-                                   Value valueToTile,
-                                   const SliceParameters &sliceParams) {
+static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
+                                        Value valueToTile,
+                                        const SliceParameters &sliceParams) {
   auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                       .Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
                       .Default([](ShapedType) -> Operation * {
                         llvm_unreachable("Unexpected shaped type");
                       });
-  return sliceOp->getResult(0);
+  return sliceOp;
 }
 
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
-                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
-                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
-                     ArrayRef<OpFoldResult> subShapeSizes,
-                     bool omitPartialTileCheck) {
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                          ArrayRef<OpFoldResult> lbs,
+                          ArrayRef<OpFoldResult> ubs,
+                          ArrayRef<OpFoldResult> subShapeSizes,
+                          bool omitPartialTileCheck) {
   SliceParameters sliceParams =
       computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                              ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
     tiledShapes.push_back(
         sliceParams.has_value()
             ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+                  ->getResult(0)
             : valueToTile);
   }
   return tiledShapes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..fc27f96500abc7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -854,7 +854,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
-          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+                       /*generatedSlices=*/{}};
       return success();
     }
 
@@ -910,12 +911,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // op.
   if (loops.empty()) {
     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues};
+                                tilingResult->tiledValues,
+                                tilingResult->generatedSlices};
   }
 
   SmallVector<Value> replacements = llvm::map_to_vector(
       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+                              tilingResult->generatedSlices};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -1180,9 +1183,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
         .set(origDestinationTensors[resultNumber]);
   }
-  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0],
-                                           tileAndFuseResult->tiledOps};
+  return scf::SCFFuseProducerOfSliceResult{
+      fusableProducer, tileAndFuseResult->tiledValues[0],
+      tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -1358,48 +1361,55 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   //    operations. If the producers of the source of the `tensor.extract_slice`
   //    can be tiled such that the tiled value is generated in-place, that
   //    effectively tiles + fuses the operations.
-  auto addCandidateSlices = [](Operation *fusedOp,
-                               std::deque<tensor::ExtractSliceOp> &candidates) {
-    for (Value operand : fusedOp->getOperands())
-      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-        candidates.push_back(sliceOp);
+  struct WorklistItem {
+    tensor::ExtractSliceOp candidateSlice;
+    SCFTileAndFuseOptions::ControlFnResult controlFnResult;
+  };
+  std::deque<WorklistItem> worklist;
+  auto addCandidateSlices = [&worklist, &options,
+                             &loops](ArrayRef<Operation *> candidates) {
+    for (auto candidate : candidates) {
+      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
+      if (!sliceOp)
+        continue;
+
+      auto [fusableProducer, destinationInitArg] =
+          getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
+      if (!fusableProducer)
+        continue;
+      std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
+          options.fusionControlFn(sliceOp, fusableProducer,
+                                  destinationInitArg.has_value());
+      if (!controlFnResult)
+        continue;
+      worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
+    }
   };
 
-  std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tilingResult->generatedSlices);
   OpBuilder::InsertionGuard g(rewriter);
-  while (!candidates.empty()) {
+  while (!worklist.empty()) {
     // Traverse the slices in BFS fashion.
-    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
-    candidates.pop_front();
-
-    // Find the original producer of the slice.
-    auto [fusableProducer, destinationInitArg] =
-        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
-                                          loops);
-    if (!fusableProducer)
-      continue;
-
-    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
-        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
-    if (!fuseSlice)
-      continue;
+    WorklistItem worklistItem = worklist.front();
+    worklist.pop_front();
 
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+        tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
+                                   loops);
     if (!fusedResult)
       continue;
 
-    if (yieldReplacement) {
+    if (worklistItem.controlFnResult.yieldProducerReplacement) {
       // Reconstruct and yield all opResult of fusableProducerOp by default. The
       // caller can specific which one to yield by designating optional argument
       // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
-      Operation *fusableProducerOp = fusableProducer.getOwner();
+      Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
-              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+              rewriter, worklistItem.candidateSlice, fusedResult.value(),
+              loops))) {
         return rewriter.notifyMatchFailure(
             fusableProducerOp, "failed to replacement value for this "
                                "operation from within the tiled loop");
@@ -1412,12 +1422,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       }
     }
 
-    if (Operation *tiledAndFusedOp =
-            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
-      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
-      tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
-    }
+    addCandidateSlices(fusedResult->generatedSlices);
   }
 
   DenseMap<Value, Value> replacements;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9e17184ebed794..104d6ae1f9f6b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -187,8 +187,9 @@ struct PackOpTiling
     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
 
     SmallVector<Value> tiledOperands;
-    tiledOperands.push_back(b.create<ExtractSliceOp>(
-        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+    auto sourceSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides);
+    tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outputOffsets, outputSizes;
     if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
@@ -196,9 +197,9 @@ struct PackOpTiling
       return {};
 
     strides.append(packOp.getDestRank() - inputRank, oneAttr);
-    auto extractSlice = b.create<ExtractSliceOp>(
+    auto outSlice = b.create<ExtractSliceOp>(
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
-    tiledOperands.push_back(extractSlice);
+    tiledOperands.push_back(outSlice);
 
     if (auto val = packOp.getPaddingValue())
       tiledOperands.push_back(val);
@@ -206,10 +207,12 @@ struct PackOpTiling
       tiledOperands.push_back(tile);
 
     Operation *tiledPackOp = b.create<PackOp>(
-        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
 
-    return TilingResult{{tiledPackOp},
-                        SmallVector<Value>(tiledPackOp->getResults())};
+    return TilingResult{
+        {tiledPackOp},
+        SmallVector<Value>(tiledPackOp->getResults()),
+        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
 
   LogicalResult
@@ -348,8 +351,9 @@ struct PackOpTiling
     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
 
     SmallVector<Value> tiledOperands;
-    tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
-                                                     offsets, sizes, strides));
+    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
+                                                offsets, sizes, strides);
+    tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     if (failed(getIterationDomainTileFromOperandTile(
@@ -363,19 +367,21 @@ struct PackOpTiling
       return failure();
 
     strides.append(packOp.getDestRank() - inputRank, oneAttr);
-    auto extractSlice = b.create<ExtractSliceOp>(
+    auto outSlice = b.create<ExtractSliceOp>(
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
-    tiledOperands.push_back(extractSlice);
+    tiledOperands.push_back(outSlice);
 
     assert(!packOp.getPaddingValue() && "Expect no padding semantic");
     for (auto tile : packOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
     Operation *tiledPackOp = b.create<PackOp>(
-        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
 
-    return TilingResult{{tiledPackOp},
-                        SmallVector<Value>(tiledPackOp->getResults())};
+    return TilingResult{
+        {tiledPackOp},
+        SmallVector<Value>(tiledPackOp->getResults()),
+        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
 };
 
@@ -554,9 +560,12 @@ struct UnPackOpTiling
 
     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
     Value sliceDest;
+    SmallVector<Operation *> generatedSlices;
     if (isPerfectTilingCase) {
-      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
-                                           sizes, destStrides);
+      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
+                                                  offsets, sizes, destStrides);
+      sliceDest = destSliceOp;
+      generatedSlices.push_back(destSliceOp);
     } else {
       sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                     unpackOp.getDestType().getElementType());
@@ -571,12 +580,15 @@ struct UnPackOpTiling
 
     if (isPerfectTilingCase)
       return TilingResult{{tiledUnpackOp},
-                          SmallVector<Value>(tiledUnpackOp->getResults())};
+                          SmallVector<Value>(tiledUnpackOp->getResults()),
+                          generatedSlices};
 
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+    generatedSlices.push_back(extractSlice);
+    return TilingResult{
+        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
   }
 
   LogicalResult
@@ -697,7 +709,9 @@ struct UnPackOpTiling
                            tiledOperands, op->getAttrs());
 
     return TilingResult{{tiledUnPackOp},
-                        SmallVector<Value>(tiledUnPackOp->getResults())};
+                        SmallVector<Value>(tiledUnPackOp->getResults()),
+                        llvm::to_vector(ArrayRef<Operation *>{
+                            extractSourceSlice, extractDestSlice})};
   }
 };
 
@@ -867,7 +881,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   // the result shape of the new SliceOp has a zero dimension.
   auto createPadOfExtractSlice = [&]() {
     // Create pad(extract_slice(x)).
-    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
+    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
         loc, padOp.getSource(), newOffsets, newLengths, newStrides);
     auto newPadOp = b.create<PadOp>(
         loc, Type(), newSliceOp, newLows, newHighs,
@@ -879,14 +893,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
     padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
 
     // Cast result and return.
-    return newPadOp;
+    return std::make_tuple(newPadOp, newSliceOp);
   };
 
   // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
   // the original data source x is not used.
   if (hasZeroLen) {
     Operation *generateOp = createGenerateOp();
-    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+    return TilingResult{{generateOp},
+                        {castResult(generateOp->getResult(0))},
+                        /*generatedSlices=*/{}};
   }
 
   // If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -894,6 +910,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   if (generateZeroSliceGuard && dynHasZeroLenCond) {
     Operation *thenOp;
     Operation *elseOp;
+    Operation *sliceOp;
     auto result = b.create<scf::IfOp>(
         loc, dynHasZeroLenCond,
         /*thenBuilder=*/
@@ -903,14 +920,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
         },
         /*elseBuilder=*/
         [&](OpBuilder &b, Location loc) {
-          elseOp = createPadOfExtractSlice();
+          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
           b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
         });
-    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+    return TilingResult{
+        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
   }
 
-  Operation *newPadOp = createPadOfExtractSlice();
-  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
+  return TilingResult{
+      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
 }
 
 void mlir::tensor::registerTilingInterfaceExternalModels(
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7aa7b58433f36c..b6da47977cb4cf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -91,11 +91,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
 
     scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
         [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
-            bool isDestinationOperand) {
-          Operation *owner = originalProducer.getOwner();
-          bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
-          return std::make_tuple(true, yieldProducerReplacement);
-        };
+            bool isDestinationOperand)
+        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+      Operation *owner = originalProducer.getOwner();
+      bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+      return scf::SCFTileAndFuseOptions::ControlFnResult{
+          yieldProducerReplacement};
+    };
     tileAndFuseOptions.setFusionControlFn(controlFn);
 
     rewriter.setInsertionPoint(target);
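
Reader's note (not part of the patch): the pattern this diff applies in each `getTiledImplementation` can be sketched in isolation. The sketch below uses mock stand-ins, not the MLIR classes: `MockOp`, `MockTilingResult`, `MockControlFnResult`, `mockControlFn`, `collectFusionCandidates` and the `fusable` flag are all hypothetical names made up for illustration. It only assumes what the diff shows: the tiling result carries a third list of generated slices, and the control function returns `std::nullopt` when fusion should not be performed.

```cpp
#include <optional>
#include <string>
#include <vector>

// Mock stand-in for an operation; this is NOT mlir::Operation.
struct MockOp {
  std::string name;
  bool fusable; // Hypothetical flag: "the control function allows fusion".
};

// Mirrors the three-field shape used in the patch: tiled ops, tiled
// values, and the slices created while tiling.
struct MockTilingResult {
  std::vector<MockOp *> tiledOps;
  std::vector<std::string> tiledValues;
  std::vector<MockOp *> generatedSlices;
};

// Mock of the control function result; only the one field seen in the diff.
struct MockControlFnResult {
  bool yieldProducerReplacement;
};

// Returns std::nullopt when fusion should not be performed for this slice.
std::optional<MockControlFnResult> mockControlFn(const MockOp &slice) {
  if (!slice.fusable)
    return std::nullopt;
  return MockControlFnResult{/*yieldProducerReplacement=*/false};
}

// Collects fusion candidates from the slices the tiling step reported,
// instead of inspecting the operands of the tiled operations.
std::vector<std::string>
collectFusionCandidates(const MockTilingResult &result) {
  std::vector<std::string> names;
  for (MockOp *slice : result.generatedSlices)
    if (mockControlFn(*slice).has_value())
      names.push_back(slice->name);
  return names;
}
```

With a result whose `generatedSlices` holds a fusable `sourceSlice` and a non-fusable `outSlice`, `collectFusionCandidates` returns only the former; the tiled op itself is never inspected.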


