[Mlir-commits] [mlir] [NFC] Simplify the tiling implementation using cloning. (PR #72178)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 13 16:42:57 PST 2023


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/72178

The current implementation of tiling using `scf.for` is convoluted because it must ensure that the destination-passing style of the untiled program is preserved. The addition of support for tiling using `scf.forall` (adapted from the transform operation in Linalg) in https://github.com/llvm/llvm-project/pull/67083 cloned the tiled operations to streamline the implementation. This PR adapts the other tiling methods to use a similar approach, making the transformations (and the handling of destination-passing-style semantics) more systematic.

>From 391c974b419e09d001d65c1b73c514d2c49c01fa Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Wed, 18 Oct 2023 23:40:06 -0700
Subject: [PATCH 1/5] Clone operation before tiling them.

Cloning operations, updating their destination operands, and then tiling
them makes the tiling logic much simpler and removes some code paths
that are very hard to reason about.
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 117 ++++++++++------
 .../TilingInterface/tile-using-interface.mlir | 126 ++++++++++++------
 .../TilingInterface/TestTilingInterface.cpp   |   2 +
 3 files changed, 162 insertions(+), 83 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index df162d29a48eb89..358740c8826c4b0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -128,6 +128,9 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                   Operation *op,
                                                   ValueRange newDestArgs) {
   Operation *clonedOp = rewriter.clone(*op);
+  if (newDestArgs.empty()) {
+    return clonedOp;
+  }
   if (auto destinationStyleOp =
           dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
     destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
@@ -139,15 +142,17 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
-/// the
-///   tile processed within the inner most loop.
+///   the tile processed within the inner most loop.
+/// Note that this method adds `scf.yield` operations for all but the innermost
+/// loop; each one yields the results of the immediately nested inner loop. The
+/// caller is expected to add the `scf.yield` operation for the innermost loop.
 static SmallVector<scf::ForOp> generateTileLoopNest(
     OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
     ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
-    SmallVector<OpFoldResult> &sizes) {
-  assert(!loopRanges.empty() && "expected at least one loop range");
-  assert(loopRanges.size() == tileSizes.size() &&
-         "expected as many tile sizes as loop ranges");
+    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
+  if (loopRanges.empty()) {
+    return {};
+  }
   OpBuilder::InsertionGuard guard(builder);
   SmallVector<scf::ForOp> loops;
   offsets.resize(loopRanges.size());
@@ -169,17 +174,25 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
     }
 
     auto loop = builder.create<scf::ForOp>(
-        loc, offset, size, tileSize, ValueRange{},
+        loc, offset, size, tileSize, destinationTensors,
         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
             ValueRange /*iterArgs*/) {
           sizes[loopRange.index()] =
               getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
                                  getAsOpFoldResult(tileSize));
-          builder.create<scf::YieldOp>(loc);
         });
     offsets[loopRange.index()] = loop.getInductionVar();
     loops.push_back(loop);
-    builder.setInsertionPoint(loop.getBody()->getTerminator());
+    builder.setInsertionPointToEnd(loop.getBody());
+    destinationTensors = loop.getRegionIterArgs();
+  }
+
+  // Add the scf.yield operations for all the outer loops.
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip(MutableArrayRef(loops).drop_back(),
+                 MutableArrayRef(loops).drop_front())) {
+    builder.setInsertionPointToEnd(outerLoop.getBody());
+    builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
   }
   return loops;
 }
@@ -317,10 +330,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
-  if (numLoops == 0) {
-    return rewriter.notifyMatchFailure(
-        op, "unable to tile op with no iteration domain");
-  }
 
   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
   // skips tiling a particular dimension. This convention is significantly
@@ -333,6 +342,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
   }
 
+  // 3. Find the destination tensors to use for the operation.
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unable to create destination tensors");
+  }
+
   SmallVector<OpFoldResult> offsets, sizes;
   SmallVector<scf::ForOp> forLoops;
   {
@@ -354,11 +371,12 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
       applyPermutationToVector(tileSizeVector, interchangeVector);
     }
 
-    // 3. Materialize an empty loop nest that iterates over the tiles. These
+    // 4. Materialize an empty loop nest that iterates over the tiles. These
     // loops for now do not return any values even if the original operation has
     // results.
     forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
-                                    tileSizeVector, offsets, sizes);
+                                    tileSizeVector, offsets, sizes,
+                                    destinationTensors);
 
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -375,17 +393,29 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     }
   });
 
-  // 4. Generate the tiled implementation within the inner most loop.
-  if (!forLoops.empty())
-    rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
-  FailureOr<TilingResult> tiledImplementation =
-      op.getTiledImplementation(rewriter, offsets, sizes);
+  // 5. Generate the tiled implementation within the inner most loop.
+  SmallVector<Value> clonedOpDestination = destinationTensors;
+  if (!forLoops.empty()) {
+    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
+    clonedOpDestination =
+        llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
+                            [](BlockArgument b) -> Value { return b; });
+  }
 
-  if (op->getNumResults() == 0) {
-    return scf::SCFTilingResult{
-        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
+  // 5a. Clone the operation within the loop body.
+  auto clonedOp = cast<TilingInterface>(
+      cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+
+  // 5b. Tile the cloned operation.
+  FailureOr<TilingResult> tiledImplementation =
+      clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+  if (failed(tiledImplementation)) {
+    return rewriter.notifyMatchFailure(op, "failed to tile operation");
   }
 
+  // 5c. Delete the cloned operation.
+  rewriter.eraseOp(clonedOp);
+
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
   if (forLoops.empty()) {
@@ -394,30 +424,39 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                                 tiledImplementation->tiledValues};
   }
 
-  // 5. Yield all the results of the tiled operation. The surrounding loop
-  //    nest is modified to insert a destructive update pattern to yield
-  //    from the loop nest values to replace the untiled op with.
+  if (op->getNumResults() == 0) {
+    // The innermost loop does not have a `scf.yield` yet. There is nothing to
+    // return, so generate an empty `scf.yield` operation.
+    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
+    rewriter.create<scf::YieldOp>(op->getLoc());
+    return scf::SCFTilingResult{
+        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
+  }
+
+  // 6. Yield all the results of the tiled operation.
   int64_t numResults = op->getNumResults();
   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
       resultSizesList(numResults);
-  for (const auto &result : llvm::enumerate(op->getResults())) {
-    if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
-                                        sizes,
-                                        resultOffsetsList[result.index()],
-                                        resultSizesList[result.index()]))) {
+  SmallVector<Value> yieldedValues;
+  for (auto [index, tiledValue] :
+       llvm::enumerate(tiledImplementation->tiledValues)) {
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                        resultOffsets, resultSizes))) {
       return rewriter.notifyMatchFailure(
           op, "failed to get slice of result produced");
     }
+    SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
+                                            rewriter.getIndexAttr(1));
+    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
+        op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
+        resultSizes, resultStrides);
+    yieldedValues.push_back(insertSlice);
   }
+  rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);
 
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors)))
-    return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
-  SmallVector<Value> replacements = yieldTiledValues(
-      rewriter, destinationTensors, tiledImplementation.value(),
-      resultOffsetsList, resultSizesList, forLoops);
+  SmallVector<Value> replacements = llvm::map_to_vector(
+      forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
   LLVM_DEBUG({
     if (!forLoops.empty()) {
       llvm::dbgs() << "After tiled implementation :\n";
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 2153eb6f237fcfd..3dcd840a4235a89 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -100,37 +100,37 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
     } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
   return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
 }
-//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
-//      CHECK-LABEL: func.func @multi_result(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
-//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
-//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
-//  CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
-//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty()
-//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
-//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
-// CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-//      CHECK:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
-//      CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
-// CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
-//  CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
-// CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
-//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
-// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
-//  CHECK-DAG:       %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
-// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
-//      CHECK:       %[[RESULT_TILE:.+]]:2 = linalg.generic
-// CHECK-SAME:           ins(%[[ARG_TILE]] :
-// CHECK-SAME:           outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
-//      CHECK:       %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
-// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
-//      CHECK:       %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
-// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
-//      CHECK:       scf.yield %[[UPDATE0]], %[[UPDATE1]]
-//      CHECK:     scf.yield %[[INNER]]#0, %[[INNER]]#1
-//      CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
+//   CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+// CHECK-LABEL: func.func @multi_result(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//   CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty()
+//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
+//       CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+//  CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//       CHECK:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
+//       CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+//  CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//   CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+//  CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
+//   CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+//  CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+//   CHECK-DAG:       %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+//  CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+//       CHECK:       %[[RESULT_TILE:.+]]:2 = linalg.generic
+//  CHECK-SAME:           ins(%[[ARG_TILE]] :
+//  CHECK-SAME:           outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//       CHECK:       %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+//  CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+//       CHECK:       %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+//  CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+//       CHECK:       scf.yield %[[UPDATE0]], %[[UPDATE1]]
+//       CHECK:     scf.yield %[[INNER]]#0, %[[INNER]]#1
+//       CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
 
 // -----
 
@@ -193,14 +193,9 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
 
 // -----
 
-// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECK-LABEL: @indexed_semantics
 func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // Check that we correctly amend "linalg.index" results.
 
-  // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-  // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
   %0 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
@@ -209,13 +204,8 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
     ins(%arg0: tensor<?x?xf32>)
     outs(%arg1: tensor<?x?xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):
-    // CHECK: %[[INDEX0:.+]] = linalg.index 0
-    // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
     %1 = linalg.index 0 : index
-    // CHECK: %[[INDEX1:.+]] = linalg.index 1
-    // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
     %2 = linalg.index 1 : index
-    // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
     %3 = arith.addi %1, %2 : index
     %4 = arith.index_cast %3 : index to i64
     %5 = arith.uitofp %4 : i64 to f32
@@ -224,6 +214,15 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
   } -> (tensor<?x?xf32>)
   return %0 : tensor<?x?xf32>
 }
+//       CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @indexed_semantics
+//       CHECK:   scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+//       CHECK:     scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+//       CHECK:       %[[INDEX0:.+]] = linalg.index 0
+//       CHECK:       %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+//       CHECK:       %[[INDEX1:.+]] = linalg.index 1
+//       CHECK:       %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+//       CHECK:       arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
 
 // -----
 
@@ -276,14 +275,53 @@ func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 
 // -----
 
+func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
+  linalg.copy {__internal_transform__ = "simple_copy_memref"}
+      ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
+  return
+}
 // CHECK-LABEL: func @linalg_copy_matmul(
 //       CHECK:   scf.for
 //       CHECK:     scf.for
 //       CHECK:       memref.subview
 //       CHECK:       memref.subview
 //       CHECK:       linalg.copy
-func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
-  linalg.copy {__internal_transform__ = "simple_copy_memref"}
-      ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
+
+// -----
+
+func.func @check_scalar_operation(%arg0 : tensor<f32>) -> tensor<f32> {
+  %init = tensor.empty() : tensor<f32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}
+      {__internal_transform__ = "scalar_op"}
+      ins(%arg0 : tensor<f32>) outs(%init : tensor<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+// CHECK-LABEL: func @check_scalar_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       __internal_transform__ = "tiled"
+
+// -----
+
+func.func @check_scalar_memref_operation(%arg0 : memref<f32>, %arg1 : memref<f32>){
+  linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}
+      {__internal_transform__ = "scalar_op"}
+      ins(%arg0 : memref<f32>) outs(%arg1 : memref<f32>){
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 : f32
+  }
   return
 }
+// CHECK-LABEL: func @check_scalar_memref_operation
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       __internal_transform__ = "tiled"
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index e5d7dc54409e447..112ad6cbde85894 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -579,6 +579,8 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3});
     // 10. Tiling M and N dims of `linalg.copy` on memrefs.
     addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
+    // 11. Tiling scalar operations.
+    addPatternForTiling(context, patterns, "scalar_op", {});
     return;
   }
   if (testTilingForAll) {

>From 7455bf1bb6af099e2a694dfbd3840d6dc6b46c59 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Mon, 13 Nov 2023 15:16:36 -0800
Subject: [PATCH 2/5] Clone operation before tile and fuse

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 82 +++++++++++--------
 mlir/test/Dialect/Tensor/tiling.mlir          |  4 +-
 .../tile-and-fuse-using-interface.mlir        |  4 +-
 3 files changed, 53 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 358740c8826c4b0..40628f5119b0d58 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -583,17 +583,54 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                         loops);
   if (!fusableProducer)
     return std::nullopt;
+  unsigned resultNumber = fusableProducer.getResultNumber();
 
-  // 2. Generate the tiled implementation of the producer of the source
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(candidateSliceOp);
+
+  // 2. Clone the fused producer
+  // 2a. Compute the destination operands to use for the cloned operation.
+  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
+  Operation *fusableProducerOp = fusableProducer.getOwner();
+  if (isa<DestinationStyleOpInterface>(fusableProducerOp)) {
+    if (failed(tensor::getOrCreateDestinations(
+            rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
+            origDestinationTensors))) {
+      return std::nullopt;
+    }
+  }
+  clonedOpDestinationTensors = origDestinationTensors;
+  if (destinationInitArg &&
+      isa<DestinationStyleOpInterface>(fusableProducerOp)) {
+    // 2b. If the producer is also destination style, then to maintain the
+    // destination passing style, update the destination of the producer to be
+    // the source of the slice.
+    clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
+  }
+  // 2c. Clone the fused producer.
+  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
+      rewriter, fusableProducerOp, clonedOpDestinationTensors);
+  // 2d. Update the source of the candidateSlice to be the cloned producer.
+  //     It is easier to just clone the slice with a different source, since
+  //     replacement and DCE of the cloned ops then become simpler.
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
+  tensor::ExtractSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 3. Generate the tiled implementation of the producer of the source
   FailureOr<TilingResult> tileAndFuseResult =
-      tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
-                                                   fusableProducer);
+      tensor::replaceExtractSliceWithTiledProducer(
+          rewriter, clonedCandidateSliceOp,
+          clonedProducerOp->getResult(resultNumber));
   if (failed(tileAndFuseResult))
     return std::nullopt;
   rewriter.replaceAllUsesWith(candidateSliceOp,
                               tileAndFuseResult->tiledValues[0]);
+  rewriter.eraseOp(clonedCandidateSliceOp);
+  rewriter.eraseOp(clonedProducerOp);
 
   // 3. If the slice is for a destination operand, for example,
   //
@@ -615,7 +652,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // %1 = linalg.fill
   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
-  //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+  //     %4 = tensor.extract_slice %arg1[..]
   //     %5 = linalg.fill .. outs(%4 : )
   //     .. = linalg.matmul .. outs(%5 : )
   //   }
@@ -624,46 +661,25 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   //
   // The untiled `linalg.fill` is still used as the `init_value` since it
   // was originally a destination operand of the untiled `linalg.matmul`.
-  // When fusing an operand that is a destination operand.
-  //   - Update the iter_arg of the outer most loop to use the destination
-  //     of the untiled producer.
-  //   - Update the destination of the slice of the tiled producer generated
-  //     to use the same basic block argument as the slice that was used to
-  //     generate inplace the tiled implementation of the producer.
-  // With this the IR will be.
+  // When fusing an operand that is a destination operand, the iter_arg of
+  // the outer most loop should be changed to use the destination of the
+  // fused operation. With this the IR will be.
   //
   // ```
   // %0 = linalg.init
   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
-  //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+  //     %3 = tensor.extract_slice %arg1[..]
   //     %4 = linalg.fill .. outs(%3 : )
   //     .. = linalg.matmul .. outs(%4 : )
   //   }
   // }
   // ```
-  // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
-  // Update to use that when it does become available.
-  scf::ForOp outerMostLoop = loops.front();
   if (destinationInitArg &&
-      (*destinationInitArg)->getOwner() == outerMostLoop) {
-    unsigned iterArgNumber =
-        outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
-    int64_t resultNumber = fusableProducer.getResultNumber();
-    if (auto dstOp =
-            dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
-      (*destinationInitArg)
-          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
-    }
-    for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
-      auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
-      if (!dstOp)
-        continue;
-      scf::ForOp innerMostLoop = loops.back();
-      updateDestinationOperandsForTiledOp(
-          rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
-    }
+      isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
+    loops.front()
+        ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
+        .set(origDestinationTensors[resultNumber]);
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
                                            tileAndFuseResult->tiledValues[0],
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 51f33a96e571b83..bb42f84afc50f94 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -374,8 +374,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK:               %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
 // CHECK:               %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]]
 // CHECK:               %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:               %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]]
-// CHECK:               %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]]
+// CHECK:               %[[OUT_D2:.+]] = tensor.dim %[[ITER1]], %[[C2]]
+// CHECK:               %[[OUT_D3:.+]] = tensor.dim %[[ITER1]], %[[C3]]
 // CHECK:               %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:               %[[PACK:.*]] = tensor.pack
 // CHECK-SAME:            %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index cf5a1b828f95b75..2078b5b4dabb268 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -369,15 +369,15 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //   CHECK-DAG:   %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
-//   CHECK-DAG:   %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
 //   CHECK-DAG:   %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
 //   CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
 //   CHECK-DAG:   %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
 //   CHECK-DAG:   %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
 //       CHECK:   %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
 //  CHECK-SAME:       iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+//   CHECK-DAG:     %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
+//   CHECK-DAG:     %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //   CHECK-DAG:     %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%[[M]]]
 //   CHECK-DAG:     %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
 //   CHECK-DAG:     %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]

>From cf5b8c90a11b6bf558d1d0cafe57089415f45646 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Mon, 30 Oct 2023 15:58:32 -0700
Subject: [PATCH 3/5] Clone in Partial tile reduction.

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 67 ++++++++++++-------
 1 file changed, 42 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 40628f5119b0d58..56f82c400b7c8ec 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -496,42 +496,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
       reductionDims.push_back(idx);
   }
 
-  // 1. create the inital tensor value.
+  // 2. Create the initial tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                   reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
-  // 2. Create the nested loops.
+  // 3. Create the nested loops.
   SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> loops = generateTileLoopNest(
-      b, loc, iterationDomain, tileSizesVector, offsets, sizes);
+  SmallVector<scf::ForOp> loops =
+      generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
+                           sizes, identityTensor.value()->getResults());
+
+  // 4. Generate the tiled implementation within the innermost loop.
+  // 4a. Clone the operation within the loop body.
+  SmallVector<Value> clonedOpDestination =
+      llvm::map_to_vector(identityTensor.value()->getResults(),
+                          [](OpResult res) -> Value { return res; });
+  if (!loops.empty()) {
+    b.setInsertionPointToEnd(loops.back().getBody());
+    clonedOpDestination =
+        llvm::map_to_vector(loops.back().getRegionIterArgs(),
+                            [](BlockArgument b) -> Value { return b; });
+  }
+  auto clonedOp = cast<PartialReductionOpInterface>(
+      cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
 
-  // 3. Generate the tiled implementation within the inner most loop.
-  b.setInsertionPoint(loops.back().getBody()->getTerminator());
-  Operation *parallelOp = op.tileToPartialReduction(
-      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
+  // 4b. Tile the cloned operation.
+  Operation *parallelOp = clonedOp.tileToPartialReduction(
+      b, loc, clonedOpDestination, offsets, sizes, reductionDims);
+  // 4c. Delete the cloned operation.
+  b.eraseOp(clonedOp);
 
-  SmallVector<OpFoldResult> resultSizesList;
-  for (size_t i = 0; i < offsets.size(); i++)
-    resultSizesList.push_back(
+  SmallVector<OpFoldResult> outSizes;
+  for (size_t i = 0; i < offsets.size(); i++) {
+    outSizes.push_back(
         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+  }
   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  SmallVector<Value> replacements = yieldTiledValues(
-      b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
-      resultSizesList, loops);
-
-  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
-  auto innerMostLoop = loops.back();
-  SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
-  assert(destinationTensors.size() ==
-             innerMostLoop.getRegionIterArgs().size() &&
-         "unexpected number of outputs");
-  updateDestinationOperandsForTiledOp(b, destinationTensors,
-                                      innerMostLoop.getRegionIterArgs());
-
-  // 4. Apply the merge reduction to combine all the partial values.
+  SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
+  SmallVector<Value> yieldedVals;
+  auto bbArgs = loops.back().getRegionIterArgs();
+  for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
+    Value insert = b.create<tensor::InsertSliceOp>(
+        loc, result, bbArg, outOffsets, outSizes, outStrides);
+    yieldedVals.push_back(insert);
+  }
+  b.create<scf::YieldOp>(loc, yieldedVals);
+
+  SmallVector<Value> replacements = llvm::map_to_vector(
+      loops.front().getResults(), [](OpResult r) -> Value { return r; });
+
+  // 5. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
   b.replaceOp(op, mergeOp->getResults());

>From f94d33cee2116bdd47e058b0a6f57445e04c37c3 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Mon, 13 Nov 2023 15:09:45 -0800
Subject: [PATCH 4/5] Fix tile+fuse+yield

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 129 +++++++++++++++---
 1 file changed, 110 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 56f82c400b7c8ec..4bc3198dc09fb90 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -197,6 +197,79 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
   return loops;
 }
 
+/// Method to add new init values to a loop nest. Updates `loops` in-place with
+/// new loops that use the `newInitValues`.
+/// The outer-loops are updated to yield the new result values of the inner
+/// loop. For the innermost loop, the call back `getNewYieldValsFn` is invoked
+/// to get the additional values to yield from the innermost loop. It is called
+/// with the insertion point set just before the innermost `scf.yield` and must
+/// return exactly one value per entry of `newInitValues`.
+static void addInitOperandsToLoopNest(
+    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops,
+    ValueRange newInitValues,
+    llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv,
+                                          ValueRange newRegionIterArgs)>
+        getNewYieldValsFn) {
+  if (loops.empty()) {
+    return;
+  }
+  OpBuilder::InsertionGuard g(rewriter);
+  for (auto &loop : loops) {
+    rewriter.setInsertionPoint(loop);
+
+    // Create a new loop with the new init values for this loop.
+    SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs());
+    newInits.append(newInitValues.begin(), newInitValues.end());
+    auto newLoop = rewriter.create<scf::ForOp>(
+        loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
+        loop.getStep(), newInits,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+
+    // Move the body of the old loop into the new loop. The old loop's
+    // iter_args map onto the leading iter_args of the new loop.
+    SmallVector<Value> sourceBlockArgs;
+    sourceBlockArgs.push_back(newLoop.getInductionVar());
+    auto newRegionIterArgs = newLoop.getRegionIterArgs();
+    sourceBlockArgs.append(
+        newRegionIterArgs.begin(),
+        std::next(newRegionIterArgs.begin(), loop.getNumResults()));
+    rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs);
+    rewriter.replaceOp(loop,
+                       newLoop.getResults().take_front(loop.getNumResults()));
+    loop = newLoop;
+    // The next (inner) loop takes the new iter_args of this loop as inits.
+    newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
+  }
+
+  // Update the loop body of the innermost loop to get new yield values.
+  scf::ForOp innerMostLoop = loops.back();
+  auto innerMostYieldOp =
+      cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator());
+  rewriter.setInsertionPoint(innerMostYieldOp);
+  SmallVector<Value> newYieldVals =
+      getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(),
+                        innerMostLoop.getRegionIterArgs());
+  SmallVector<Value> newYieldOperands =
+      llvm::to_vector(innerMostYieldOp->getOperands());
+  newYieldOperands.append(newYieldVals);
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands);
+
+  // Make all other loops except the innermost loop yield the values returned
+  // by the inner loop.
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip(loops.drop_back(), loops.drop_front())) {
+    auto outerLoopYield =
+        cast<scf::YieldOp>(outerLoop.getBody()->getTerminator());
+    SmallVector<Value> newYields =
+        llvm::to_vector(outerLoopYield.getOperands());
+    ValueRange additionalYields =
+        innerLoop.getResults().take_back(newInitValues.size());
+    newYields.append(additionalYields.begin(), additionalYields.end());
+    rewriter.setInsertionPoint(outerLoopYield);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+  }
+}
+
 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
 /// construct the destructive update pattern that inserts the yielded
 /// value into a destination tensor provided by `initValue` at offset
@@ -644,6 +715,8 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
           clonedProducerOp->getResult(resultNumber));
   if (failed(tileAndFuseResult))
     return std::nullopt;
+  // Note: `candidateSliceOp` is passed in by the caller, so only its uses are
+  // replaced here; it is not erased (only the cloned slice is).
   rewriter.replaceAllUsesWith(candidateSliceOp,
                               tileAndFuseResult->tiledValues[0]);
   rewriter.eraseOp(clonedCandidateSliceOp);
@@ -708,28 +781,50 @@ void mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<scf::ForOp> loops) {
-  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
-      fusedProducerInfo;
-  SmallVector<Value> initValues;
+  if (loops.empty()) {
+    return;
+  }
+  OpResult fusableProducer = fusedProducerInfo.origProducer;
+  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
   FailureOr<Value> initValue = tensor::getOrCreateDestination(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
   if (succeeded(initValue)) {
-    SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
-    SmallVector<Value> yieldedVals =
-        yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
-                         resultOffsets, resultSizes, loops);
-  }
-  for (auto tileAndFusedOp : tileAndFusedOps) {
-    auto dstStyleProducer =
-        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
-    if (!dstStyleProducer)
-      continue;
-    Value dstValue =
-        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
-            ->get();
-    updateDestinationOperandsForTiledOp(
-        rewriter, dstValue, loops.back().getRegionIterArgs().back());
+
+    auto newYieldValuesFn =
+        [&](RewriterBase &innerRewriter, Value /*iv*/,
+            ValueRange newRegionIterArgs) -> SmallVector<Value> {
+      // `initValue` is appended last, so its iter_arg is the last one of the
+      // innermost loop.
+      Value newRegionArg = newRegionIterArgs.back();
+      if (auto tiledDestStyleOp =
+              tiledAndFusedProducer
+                  .getDefiningOp<DestinationStyleOpInterface>()) {
+        // Make the tiled producer write into a slice of the loop-carried
+        // destination rather than into a slice of the original one.
+        OpBuilder::InsertionGuard g(innerRewriter);
+        innerRewriter.setInsertionPoint(tiledDestStyleOp);
+        auto destSlice = innerRewriter.create<tensor::ExtractSliceOp>(
+            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
+            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+        unsigned resultNumber = fusableProducer.getResultNumber();
+        innerRewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
+        });
+      }
+
+      // Yield the tiled value on every path (DPS or not): exactly one value is
+      // needed per new init value. At this point the insertion point is the
+      // caller's, i.e. just before the innermost `scf.yield`.
+      Value replacement = innerRewriter.create<tensor::InsertSliceOp>(
+          fusableProducer.getLoc(), tiledAndFusedProducer, newRegionArg,
+          sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+          sliceOp.getMixedStrides());
+      return {replacement};
+    };
+
+    addInitOperandsToLoopNest(rewriter, loops,
+                              SmallVector<Value>{initValue.value()},
+                              newYieldValuesFn);
   }
 }
 

>From c282a944e1c872666ddd890cefbd0fbc14a444bf Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Mon, 13 Nov 2023 15:10:08 -0800
Subject: [PATCH 5/5] NFC: code reorganization and deletion.

---
 .../SCF/Transforms/TileUsingInterface.h       |  12 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 117 ------------------
 2 files changed, 6 insertions(+), 123 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 81325b62791c44b..2f8f337bb8057ce 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -83,6 +83,12 @@ FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
                                              TilingInterface op,
                                              const SCFTilingOptions &options);
 
+/// Method to tile an op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                     const SCFTilingOptions &options);
+
 /// Options used to control tile + fuse.
 struct SCFTileAndFuseOptions {
   /// The tiling options used to control the tiling of the consumer.
@@ -93,12 +99,6 @@ struct SCFTileAndFuseOptions {
   }
 };
 
-/// Method to tile an op that implements the `TilingInterface` using
-/// `scf.forall`.
-FailureOr<SCFTilingResult>
-tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                     const SCFTilingOptions &options);
-
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 4bc3198dc09fb90..01102e786acf2be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -268,123 +268,6 @@ static void addInitOperandsToLoopNest(
   }
 }
 
-/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
-/// construct the destructive update pattern that inserts the yielded
-/// value into a destination tensor provided by `initValue` at offset
-/// `tileOffsets` and size `tileSizes`. For example,
-///
-/// ```mlir
-/// scf.for %iv0 = ... {
-///   %0 = tiled_op
-/// }
-/// ```
-///
-/// is transformed to
-///
-/// ```mlir
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-///   %1 = tensor.extract_slice %arg
-///   %2 = tiled_op
-///   %3 = tensor.insert_slice %2 into %arg
-///   scf.yield %3
-/// }
-/// ```
-/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
-static SmallVector<Value>
-yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
-                 ValueRange yieldedValues,
-                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
-                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
-                 MutableArrayRef<scf::ForOp> loops) {
-  NewYieldValuesFn yieldValueFn =
-      [&](OpBuilder &b, Location loc,
-          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
-    SmallVector<Value> inserts;
-    for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
-      ArrayRef<OpFoldResult> tileOffsets =
-          tileOffsetsList[yieldedValue.index()];
-      ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
-      SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
-                                            b.getIndexAttr(1));
-      Value insert = b.create<tensor::InsertSliceOp>(
-          loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
-          tileOffsets, tileSizes, tileStrides);
-      inserts.push_back(insert);
-    }
-    return inserts;
-  };
-
-  SmallVector<scf::ForOp> newLoops =
-      replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
-                                   /*replaceIterOperandsUsesInLoop =*/false);
-  for (const auto &loop : llvm::enumerate(loops)) {
-    loops[loop.index()] = newLoops[loop.index()];
-  }
-  return llvm::to_vector(llvm::map_range(
-      loops.front().getResults().take_back(yieldedValues.size()),
-      [](OpResult r) -> Value { return r; }));
-}
-
-/// If the tiled operation is destination passing style, update the
-/// slice of the destination used (which refers to the untiled destination)
-/// to use the corresponding region argument of the innermost loop.
-///
-/// ```mlir
-/// %0 =
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-///   %1 = tensor.extract_slice %0
-///   %2 = tiled_op
-///   %3 = tensor.insert_slice %2 into %arg
-///   scf.yield %3
-/// }
-/// ```
-///
-/// is transformed to
-///
-/// ```mlir
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-///   %1 = tensor.extract_slice %arg
-///   %2 = tiled_op
-///   %3 = tensor.insert_slice %2 into %arg
-///   scf.yield %3
-/// }
-/// ```
-static void
-updateDestinationOperandsForTiledOp(OpBuilder &builder,
-                                    ValueRange tiledOpDestinationValues,
-                                    ValueRange bbArgsList) {
-  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
-    auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!sliceOp)
-      continue;
-    sliceOp.setOperand(0, bbArgsList[destValue.index()]);
-  }
-}
-
-/// Helper method to yield the values of the tiled op, as well as
-/// update the destination operands of the tiled op, if it is
-/// a destination passing style op.
-static SmallVector<Value>
-yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
-                 TilingResult tilingResult,
-                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
-                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
-                 MutableArrayRef<scf::ForOp> loops) {
-  SmallVector<Value> replacements =
-      yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
-                       tileOffsetsList, tileSizesList, loops);
-  for (auto tiledOp : tilingResult.tiledOps) {
-    if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
-      auto innerMostLoop = loops.back();
-      SmallVector<Value> tiledOpDestinationTensors =
-          llvm::to_vector(dstOp.getDpsInits());
-      updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
-                                          innerMostLoop.getRegionIterArgs());
-    }
-  }
-  return replacements;
-}
-
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>



More information about the Mlir-commits mailing list