[Mlir-commits] [mlir] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (PR #107882)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 09:02:28 PDT 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/107882

>From 0079cff596c5bb04503a05903d194dda0416b66c Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sat, 7 Sep 2024 23:31:58 -0700
Subject: [PATCH] [mlir][TilingInterface] Avoid looking at operands for getting
 slices to continue tile + fuse.

The current implementation of `scf::tileConsumerAndFuseProducersUsingSCF`
looks at the operands of tiled/tiled+fused operations to see if they are
produced by `extract_slice` operations, and uses those to populate the
worklist used to continue fusion. This implicit assumption does not
always hold. Instead, make the implementations of `getTiledImplementation`
return the slices to use to continue fusion.

This is a breaking change in two respects:

- To retain the existing behavior of
  `scf::tileConsumerAndFuseProducersUsingSCF`, all out-of-tree
  implementations of `TilingInterface::getTiledImplementation` must be
  changed to return the slices to continue fusion on. All in-tree
  implementations have been adapted accordingly.
- This change required a simplification of the `ControlFn` in
  `scf::SCFTileAndFuseOptions`. It now returns a
  `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object,
  which should be `std::nullopt` if fusion is not to be performed.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 ++-
 .../SCF/Transforms/TileUsingInterface.h       | 33 ++++---
 .../include/mlir/Interfaces/TilingInterface.h |  7 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 81 ++++++++++------
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 26 +++++-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 20 ++--
 .../SCF/Transforms/TileUsingInterface.cpp     | 93 +++++++++++--------
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   | 71 ++++++++------
 .../tile-and-fuse-using-interface.mlir        | 45 +++++++++
 .../TestTilingInterfaceTransformOps.cpp       | 12 ++-
 10 files changed, 270 insertions(+), 129 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 65a1a8b42e1495..f1df49ce3eaa36 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
 /// controls whether to omit the partial/boundary tile condition check in
 /// cases where we statically know that it is unnecessary.
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
-                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
-                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
-                     ArrayRef<OpFoldResult> subShapeSizes,
-                     bool omitPartialTileCheck);
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                          ArrayRef<OpFoldResult> lbs,
+                          ArrayRef<OpFoldResult> ubs,
+                          ArrayRef<OpFoldResult> subShapeSizes,
+                          bool omitPartialTileCheck);
 
 /// Creates extract_slice/subview ops for all `valuesToTile` of the given
 /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..77c812cde71533 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
+  /// Slices generated after tiling that can be used for fusing with the tiled
+  /// producer.
+  SmallVector<Operation *> generatedSlices;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
   /// 2) the producer value that is to be fused
   /// 3) a boolean value set to `true` if the fusion is from
   ///    a destination operand.
-  /// It retuns two booleans
-  /// - returns `true` if the fusion should be done through the candidate slice
-  /// - returns `true` if a replacement for the fused producer needs to be
-  ///   yielded from within the tiled loop. Note that it is valid to return
-  ///   `true` only if the slice fused is disjoint across all iterations of the
-  ///   tiled loop. It is up to the caller to ensure that this is true for the
-  ///   fused producers.
-  using ControlFnTy = std::function<std::tuple<bool, bool>(
+  /// The control function returns a `std::optional<ControlFnResult>`;
+  /// `std::nullopt` means no fusion is performed along that slice.
+  struct ControlFnResult {
+    /// Set to true if the loop nest has to return a replacement value for
+    /// the fused producer. Only valid if the fused slice is disjoint across
+    /// all iterations of the tiled loop; the caller must ensure this.
+    bool yieldProducerReplacement = false;
+  };
+  using ControlFnTy = std::function<std::optional<ControlFnResult>(
       tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
       bool isDestinationOperand)>;
-  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
-    return std::make_tuple(true, false);
+  /// The default control function implements greedy fusion without yielding
+  /// a replacement for any of the fused results.
+  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
+                                   bool) -> std::optional<ControlFnResult> {
+    return ControlFnResult{};
   };
   SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
     fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> generatedSlices;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
@@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 ///
 /// The @param `yieldResultNumber` decides which result would be yield. If not
 /// given, yield all `opResult` of fused producer.
-LogicalResult yieldReplacementForFusedProducer(
+///
+/// The method returns the list of new slices added during the process (which
+/// can be used to fuse along).
+FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<LoopLikeOpInterface> loops,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 2f51496d1b110a..b33aa1489c3116 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {
 
 /// Container for result values of tiling.
 /// - `tiledOps` contains operations created by the tiling implementation that
-/// are returned to the caller for further transformations.
+///   are returned to the caller for further transformations.
 /// - `tiledValues` contains the tiled value corresponding to the result of the
-/// untiled operation.
+///   untiled operation.
+/// - `generatedSlices` contains the list of slices that are generated during
+///   tiling. These slices can be used for fusing producers.
 struct TilingResult {
   SmallVector<Operation *> tiledOps;
   SmallVector<Value> tiledValues;
+  SmallVector<Operation *> generatedSlices;
 };
 
 /// Container for the result of merge operation of tiling.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 630985d76a0ebf..c0d09638b477bf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -67,20 +67,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
 
 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
 /// `source`.
-static Value getSlice(OpBuilder &b, Location loc, Value source,
-                      ArrayRef<OpFoldResult> offsets,
-                      ArrayRef<OpFoldResult> sizes,
-                      ArrayRef<OpFoldResult> strides) {
-  return TypeSwitch<Type, Value>(source.getType())
-      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+                           ArrayRef<OpFoldResult> offsets,
+                           ArrayRef<OpFoldResult> sizes,
+                           ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Operation *>(source.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
         return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                 strides);
       })
-      .Case<MemRefType>([&](MemRefType type) -> Value {
+      .Case<MemRefType>([&](MemRefType type) -> Operation * {
         return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                            strides);
       })
-      .Default([&](Type t) { return nullptr; });
+      .Default([&](Type t) -> Operation * { return nullptr; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -2634,10 +2634,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   auto oneAttr = builder.getI64IntegerAttr(1);
   SmallVector<OpFoldResult> strides(rank, oneAttr);
   SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+  Operation *inputSlice =
+      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+  if (!inputSlice) {
+    return emitOpError("failed to compute input slice");
+  }
+  tiledOperands.emplace_back(inputSlice->getResult(0));
+  Operation *outputSlice =
+      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+  if (!outputSlice) {
+    return emitOpError("failed to compute output slice");
+  }
+  tiledOperands.emplace_back(outputSlice->getResult(0));
 
   SmallVector<Type, 4> resultTypes;
   if (hasPureTensorSemantics())
@@ -2645,7 +2653,9 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{{tiledOp},
+                      SmallVector<Value>(tiledOp->getResults()),
+                      {inputSlice, outputSlice}};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2992,8 +3002,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
   int64_t filterRank = getFilterOperandRank();
   SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
   Location loc = getLoc();
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+  tiledOperands.emplace_back(filterSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3002,15 +3013,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
@@ -3159,8 +3174,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
   SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+  tiledOperands.emplace_back(inputSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3169,15 +3185,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
@@ -3321,8 +3341,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
                      sizes[getValueFDim()]});
   int64_t valueRank = getValueOperandRank();
   SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+  tiledOperands.emplace_back(valueSlice);
 
   SmallVector<OpFoldResult> resultOffsets, resultSizes;
   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3331,15 +3352,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
 
   int64_t outputRank = getOutputOperandRank();
   SmallVector<OpFoldResult> strides(outputRank, oneAttr);
-  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
-      loc, getOutput(), resultOffsets, resultSizes, strides));
+  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, strides);
+  tiledOperands.emplace_back(outputSlice);
 
   SmallVector<Type> resultTypes;
   resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  return TilingResult{
+      {tiledOp},
+      SmallVector<Value>(tiledOp->getResults()),
+      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fbff91a94219cc..cf5ca9aa2b0e04 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,8 +120,16 @@ struct LinalgOpTilingInterface
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     SmallVector<Value> valuesToTile = linalgOp->getOperands();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+    SmallVector<Value> tiledOperands = makeTiledShapes(
         b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Untiled operands are forwarded unchanged and may be block arguments.
+    auto isSliceOp = [](Value v) -> bool {
+      return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
+          v.getDefiningOp());
+    };
+    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+        llvm::make_filter_range(tiledOperands, isSliceOp),
+        [](Value v) -> Operation * { return v.getDefiningOp(); });
 
     SmallVector<Type> resultTensorTypes =
         getTensorOutputTypes(linalgOp, tiledOperands);
@@ -129,7 +137,8 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+    return TilingResult{
+        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
   }
 
   /// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +269,8 @@ struct LinalgOpTilingInterface
 
     return TilingResult{
         tilingResult->tiledOps,
-        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+        tilingResult->generatedSlices};
   }
 
   /// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +416,12 @@ struct LinalgOpPartialReductionInterface
     }
 
     // Step 2a: Extract a slice of the input operands.
-    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+    SmallVector<Value> tiledInputs = makeTiledShapes(
         b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+    SmallVector<Operation *> generatedSlices; // Keep slice ops only.
+    for (Value v : tiledInputs)
+      if (auto sliceOp = v.getDefiningOp<tensor::ExtractSliceOp>())
+        generatedSlices.push_back(sliceOp);
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
@@ -424,6 +438,7 @@ struct LinalgOpPartialReductionInterface
       auto extractSlice = b.create<tensor::ExtractSliceOp>(
           loc, valueToTile, initOffset, initSizes, initStride);
       tiledInits.push_back(extractSlice);
+      generatedSlices.push_back(extractSlice);
     }
 
     // Update the indexing maps.
@@ -453,7 +468,8 @@ struct LinalgOpPartialReductionInterface
     return TilingResult{
         {genericOp.getOperation()},
         llvm::map_to_vector(genericOp->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+                            [](OpResult r) -> Value { return r; }),
+        generatedSlices};
   }
 
   FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index fa0598dd96885c..6a3f2fc5fbc496 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
 
-static Value materializeTiledShape(OpBuilder &builder, Location loc,
-                                   Value valueToTile,
-                                   const SliceParameters &sliceParams) {
+static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
+                                        Value valueToTile,
+                                        const SliceParameters &sliceParams) {
   auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                       .Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
                       .Default([](ShapedType) -> Operation * {
                         llvm_unreachable("Unexpected shaped type");
                       });
-  return sliceOp->getResult(0);
+  return sliceOp;
 }
 
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
-                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
-                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
-                     ArrayRef<OpFoldResult> subShapeSizes,
-                     bool omitPartialTileCheck) {
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                          ArrayRef<OpFoldResult> lbs,
+                          ArrayRef<OpFoldResult> ubs,
+                          ArrayRef<OpFoldResult> subShapeSizes,
+                          bool omitPartialTileCheck) {
   SliceParameters sliceParams =
       computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                              ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
     tiledShapes.push_back(
         sliceParams.has_value()
             ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+                  ->getResult(0)
             : valueToTile);
   }
   return tiledShapes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..3729300588422e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -854,7 +854,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
-          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+                       /*generatedSlices=*/{}};
       return success();
     }
 
@@ -910,12 +911,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // op.
   if (loops.empty()) {
     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues};
+                                tilingResult->tiledValues,
+                                tilingResult->generatedSlices};
   }
 
   SmallVector<Value> replacements = llvm::map_to_vector(
       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+                              tilingResult->generatedSlices};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -1180,13 +1183,13 @@ mlir::scf::tileAndFuseProducerOfSlice(
         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
         .set(origDestinationTensors[resultNumber]);
   }
-  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0],
-                                           tileAndFuseResult->tiledOps};
+  return scf::SCFFuseProducerOfSliceResult{
+      fusableProducer, tileAndFuseResult->tiledValues[0],
+      tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
-LogicalResult mlir::scf::yieldReplacementForFusedProducer(
+FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<LoopLikeOpInterface> loops,
@@ -1214,6 +1217,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     }
   }
 
+  SmallVector<Operation *> generatedSlices;
   YieldTiledValuesFn newYieldValuesFn =
       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
@@ -1284,6 +1288,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
             loc, newRegionArg, offsetList[index], sizesList[index],
             SmallVector<OpFoldResult>(offsetList[index].size(),
                                       rewriter.getIndexAttr(1)));
+        generatedSlices.push_back(destSlice);
         unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
@@ -1303,8 +1308,11 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     return success();
   };
 
-  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
-                                   newYieldValuesFn);
+  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                       newYieldValuesFn))) {
+    return failure();
+  }
+  return generatedSlices;
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1358,52 +1366,62 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   //    operations. If the producers of the source of the `tensor.extract_slice`
   //    can be tiled such that the tiled value is generated in-place, that
   //    effectively tiles + fuses the operations.
-  auto addCandidateSlices = [](Operation *fusedOp,
-                               std::deque<tensor::ExtractSliceOp> &candidates) {
-    for (Value operand : fusedOp->getOperands())
-      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-        candidates.push_back(sliceOp);
+  struct WorklistItem {
+    tensor::ExtractSliceOp candidateSlice;
+    SCFTileAndFuseOptions::ControlFnResult controlFnResult;
+  };
+  std::deque<WorklistItem> worklist;
+  auto addCandidateSlices = [&worklist, &options,
+                             &loops](ArrayRef<Operation *> candidates) {
+    for (auto candidate : candidates) {
+      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
+      if (!sliceOp || sliceOp.use_empty())
+        continue;
+
+      auto [fusableProducer, destinationInitArg] =
+          getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
+      if (!fusableProducer)
+        continue;
+      std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
+          options.fusionControlFn(sliceOp, fusableProducer,
+                                  destinationInitArg.has_value());
+      if (!controlFnResult)
+        continue;
+      worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
+    }
   };
 
-  std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tilingResult->generatedSlices);
   OpBuilder::InsertionGuard g(rewriter);
-  while (!candidates.empty()) {
+  while (!worklist.empty()) {
     // Traverse the slices in BFS fashion.
-    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
-    candidates.pop_front();
-
-    // Find the original producer of the slice.
-    auto [fusableProducer, destinationInitArg] =
-        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
-                                          loops);
-    if (!fusableProducer)
-      continue;
-
-    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
-        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
-    if (!fuseSlice)
-      continue;
+    WorklistItem worklistItem = worklist.front();
+    worklist.pop_front();
 
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+        tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
+                                   loops);
     if (!fusedResult)
       continue;
 
-    if (yieldReplacement) {
+    if (worklistItem.controlFnResult.yieldProducerReplacement) {
       // Reconstruct and yield all opResult of fusableProducerOp by default. The
       // caller can specific which one to yield by designating optional argument
       // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
-      Operation *fusableProducerOp = fusableProducer.getOwner();
-      if (failed(yieldReplacementForFusedProducer(
-              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+      Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
+      FailureOr<SmallVector<Operation *>> newSlices =
+          yieldReplacementForFusedProducer(rewriter,
+                                           worklistItem.candidateSlice,
+                                           fusedResult.value(), loops);
+      if (failed(newSlices)) {
         return rewriter.notifyMatchFailure(
             fusableProducerOp, "failed to replacement value for this "
                                "operation from within the tiled loop");
       }
+      addCandidateSlices(newSlices.value());
       for (auto [index, result] :
            llvm::enumerate(fusableProducerOp->getResults())) {
         origValToResultNumber[result] = loops.front()->getNumResults() -
@@ -1411,12 +1429,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                         index;
       }
     }
-
+    addCandidateSlices(fusedResult->generatedSlices);
     if (Operation *tiledAndFusedOp =
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9e17184ebed794..104d6ae1f9f6b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -187,8 +187,9 @@ struct PackOpTiling
     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
 
     SmallVector<Value> tiledOperands;
-    tiledOperands.push_back(b.create<ExtractSliceOp>(
-        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+    auto sourceSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides);
+    tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outputOffsets, outputSizes;
     if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
@@ -196,9 +197,9 @@ struct PackOpTiling
       return {};
 
     strides.append(packOp.getDestRank() - inputRank, oneAttr);
-    auto extractSlice = b.create<ExtractSliceOp>(
+    auto outSlice = b.create<ExtractSliceOp>(
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
-    tiledOperands.push_back(extractSlice);
+    tiledOperands.push_back(outSlice);
 
     if (auto val = packOp.getPaddingValue())
       tiledOperands.push_back(val);
@@ -206,10 +207,12 @@ struct PackOpTiling
       tiledOperands.push_back(tile);
 
     Operation *tiledPackOp = b.create<PackOp>(
-        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
 
-    return TilingResult{{tiledPackOp},
-                        SmallVector<Value>(tiledPackOp->getResults())};
+    return TilingResult{
+        {tiledPackOp},
+        SmallVector<Value>(tiledPackOp->getResults()),
+        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
 
   LogicalResult
@@ -348,8 +351,9 @@ struct PackOpTiling
     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
 
     SmallVector<Value> tiledOperands;
-    tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
-                                                     offsets, sizes, strides));
+    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
+                                                offsets, sizes, strides);
+    tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     if (failed(getIterationDomainTileFromOperandTile(
@@ -363,19 +367,21 @@ struct PackOpTiling
       return failure();
 
     strides.append(packOp.getDestRank() - inputRank, oneAttr);
-    auto extractSlice = b.create<ExtractSliceOp>(
+    auto outSlice = b.create<ExtractSliceOp>(
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
-    tiledOperands.push_back(extractSlice);
+    tiledOperands.push_back(outSlice);
 
     assert(!packOp.getPaddingValue() && "Expect no padding semantic");
     for (auto tile : packOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
     Operation *tiledPackOp = b.create<PackOp>(
-        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
 
-    return TilingResult{{tiledPackOp},
-                        SmallVector<Value>(tiledPackOp->getResults())};
+    return TilingResult{
+        {tiledPackOp},
+        SmallVector<Value>(tiledPackOp->getResults()),
+        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
 };
 
@@ -554,9 +560,12 @@ struct UnPackOpTiling
 
     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
     Value sliceDest;
+    SmallVector<Operation *> generatedSlices;
     if (isPerfectTilingCase) {
-      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
-                                           sizes, destStrides);
+      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
+                                                  offsets, sizes, destStrides);
+      sliceDest = destSliceOp;
+      generatedSlices.push_back(destSliceOp);
     } else {
       sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                     unpackOp.getDestType().getElementType());
@@ -571,12 +580,15 @@ struct UnPackOpTiling
 
     if (isPerfectTilingCase)
       return TilingResult{{tiledUnpackOp},
-                          SmallVector<Value>(tiledUnpackOp->getResults())};
+                          SmallVector<Value>(tiledUnpackOp->getResults()),
+                          generatedSlices};
 
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+    generatedSlices.push_back(extractSlice);
+    return TilingResult{
+        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
   }
 
   LogicalResult
@@ -697,7 +709,9 @@ struct UnPackOpTiling
                            tiledOperands, op->getAttrs());
 
     return TilingResult{{tiledUnPackOp},
-                        SmallVector<Value>(tiledUnPackOp->getResults())};
+                        SmallVector<Value>(tiledUnPackOp->getResults()),
+                        llvm::to_vector(ArrayRef<Operation *>{
+                            extractSourceSlice, extractDestSlice})};
   }
 };
 
@@ -867,7 +881,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   // the result shape of the new SliceOp has a zero dimension.
   auto createPadOfExtractSlice = [&]() {
     // Create pad(extract_slice(x)).
-    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
+    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
         loc, padOp.getSource(), newOffsets, newLengths, newStrides);
     auto newPadOp = b.create<PadOp>(
         loc, Type(), newSliceOp, newLows, newHighs,
@@ -879,14 +893,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
     padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
 
     // Cast result and return.
-    return newPadOp;
+    return std::make_tuple(newPadOp, newSliceOp);
   };
 
   // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
   // the original data source x is not used.
   if (hasZeroLen) {
     Operation *generateOp = createGenerateOp();
-    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+    return TilingResult{{generateOp},
+                        {castResult(generateOp->getResult(0))},
+                        /*generatedSlices=*/{}};
   }
 
   // If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -894,6 +910,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   if (generateZeroSliceGuard && dynHasZeroLenCond) {
     Operation *thenOp;
     Operation *elseOp;
+    Operation *sliceOp;
     auto result = b.create<scf::IfOp>(
         loc, dynHasZeroLenCond,
         /*thenBuilder=*/
@@ -903,14 +920,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
         },
         /*elseBuilder=*/
         [&](OpBuilder &b, Location loc) {
-          elseOp = createPadOfExtractSlice();
+          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
           b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
         });
-    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+    return TilingResult{
+        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
   }
 
-  Operation *newPadOp = createPadOfExtractSlice();
-  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
+  return TilingResult{
+      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
 }
 
 void mlir::tensor::registerTilingInterfaceExternalModels(
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index d1aed593f45451..3ea1929e4ed785 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -542,3 +542,48 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:     %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
 //       CHECK:     scf.yield %[[INSERTSLICE]]
 //       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> {
+  %0 = tensor.empty() : tensor<10xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %2 = arith.addf %b0, %b0: f32
+      linalg.yield %2 : f32
+  } -> tensor<10xf32>
+  %cst = arith.constant 0.0 : f32
+  %2 = tensor.pad %1 low[4] high[2] {
+    ^bb0(%arg1 : index):
+      tensor.yield %cst : f32
+  } : tensor<10xf32> to tensor<16xf32>
+  return %2 : tensor<16xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %pad = transform.structured.match ops{["tensor.pad"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.structured.fuse %pad [8]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @pad_producer_fusion
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<10xf32>
+//       CHECK:   %[[FOR_RESULT:.+]] = scf.for
+//       CHECK:     %[[IF_RESULT:.+]] = scf.if
+//       CHECK:     else
+//       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+//       CHECK:       %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:           ins(%[[SLICE]] :
+//       CHECK:       %[[PAD:.+]] = tensor.pad %[[GENERIC]]
+//       CHECK:       %[[CAST:.+]] = tensor.cast %[[PAD]]
+//       CHECK:       scf.yield %[[CAST]]
+//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
+//       CHECK:     scf.yield %[[INSERT_SLICE]]
+//       CHECK:   return %[[FOR_RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7aa7b58433f36c..b6da47977cb4cf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -91,11 +91,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
 
     scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
         [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
-            bool isDestinationOperand) {
-          Operation *owner = originalProducer.getOwner();
-          bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
-          return std::make_tuple(true, yieldProducerReplacement);
-        };
+            bool isDestinationOperand)
+        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+      Operation *owner = originalProducer.getOwner();
+      bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+      return scf::SCFTileAndFuseOptions::ControlFnResult{
+          yieldProducerReplacement};
+    };
     tileAndFuseOptions.setFusionControlFn(controlFn);
 
     rewriter.setInsertionPoint(target);



More information about the Mlir-commits mailing list