[Mlir-commits] [mlir] 809e3d8 - [mlir][TilingInterface] Modify `TilingInterface` methods to better return the state of the transformed IR.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Mar 16 07:29:15 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-03-16T14:29:03Z
New Revision: 809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365

URL: https://github.com/llvm/llvm-project/commit/809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365
DIFF: https://github.com/llvm/llvm-project/commit/809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365.diff

LOG: [mlir][TilingInterface] Modify `TilingInterface` methods to better return the state of the transformed IR.

Currently the `getTiledImplementation` and `generateResultTileValue`
methods return just `SmallVector<Operation *>` and `FailureOr<Value>`,
respectively.

- For `getTiledImplementation`, returning an empty vector implies tiling
  wasn't done. There is also an implicit assumption that the results of
  the tiled operation correspond to the tiled values of the results of
  the original operation. This cannot handle cases where the tiled
  implementation might use multiple operations to compute the tiled
  values for the results of the untiled operation. Sometimes, the tiled
  operation might not directly give the tiled values, and might require
  casts, etc., to get a replacement.
- For `generateResultTileValue`, it is assumed that the op defining
  the returned `Value` is the operation that represents the tiled
  computation. Again, the presence of casts, etc., violates this.

Instead make these methods return
```
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};
```

The `tiledOps` are the generated operations that are relevant for
subsequent transformations. The `tiledValues` are the tiled values for
the results of the original operation. This better conveys the state of
the transformed IR.

As a consequence, the following methods also return `FailureOr<TilingResult>`:
- `tensor::replaceExtractSliceWithTiledProducer`
- `tensor::bubbleUpPadSlice`

Differential Revision: https://reviews.llvm.org/D145133

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/Interfaces/TilingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Split.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
index 30a5026cd68b3..7228a5a297ad8 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
@@ -16,6 +16,9 @@
 #include "mlir/IR/Dialect.h"
 
 namespace mlir {
+
+struct TilingResult;
+
 namespace tensor {
 
 class PadOp;
@@ -39,10 +42,10 @@ class PadOp;
 /// to guard against the case that we might take a zero-sized slice from the
 /// original source. For such cases, we `tensor.generate` to generate the
 /// full tensor.
-Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
-                            ArrayRef<OpFoldResult> offsets,
-                            ArrayRef<OpFoldResult> sizes,
-                            bool generateZeroSliceGuard = true);
+FailureOr<TilingResult> bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
+                                         ArrayRef<OpFoldResult> offsets,
+                                         ArrayRef<OpFoldResult> sizes,
+                                         bool generateZeroSliceGuard = true);
 
 /// Registers external models for Tiling interface for tensor ops.
 /// Currently, it registers:

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 01985c943527c..4cdf360c51d72 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -13,6 +13,9 @@
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
+
+struct TilingResult;
+
 namespace tensor {
 
 /// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
@@ -26,7 +29,7 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
 /// provide a mechanism to control where the application happens. With use of
 /// transform dialect that control is done within the transform dialect. Other
 /// use cases can inherit from this pattern and add necessary controls.
-FailureOr<Value> replaceExtractSliceWithTiledProducer(
+FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
 /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 99cbe21b178ca..ca570490ccf5b 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -21,6 +21,20 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 
+namespace mlir {
+
+/// Container for result values of tiling.
+/// - `tiledOps` contains operations created by the tiling implementation that
+/// are returned to the caller for further transformations.
+/// - `tiledValues` contains the tiled value corresponding to the result of the
+/// untiled operation.
+struct TilingResult {
+  SmallVector<Operation *> tiledOps;
+  SmallVector<Value> tiledValues;
+};
+
+} // namespace mlir
+
 /// Include the ODS generated interface header files.
 #include "mlir/Interfaces/TilingInterface.h.inc"
 

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 6cc0685bdae41..66382f29c2424 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"SmallVector<Operation *>",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -119,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<Value>",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4aae6458ff128..4503d451a405c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,16 +431,15 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
-                                             Diagnostic &diag,
-                                             Operation *producerOp,
-                                             Operation *containingOp) {
+static SmallVector<Operation *>
+tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
+                           Operation *producerOp, Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
     diag.attachNote(producerOp->getLoc())
         << "producer is not a TileableInterface: " << *producerOp;
-    return nullptr;
+    return {};
   }
 
   // Search the producer slices accessed within the containing operation.
@@ -455,7 +454,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
   if (it == tileableProducer->getUsers().end()) {
     diag.attachNote(tileableProducer->getLoc())
         << "could not find fusion opportunity for: " << *tileableProducer;
-    return nullptr;
+    return {};
   }
   auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
 
@@ -468,27 +467,29 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
       sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
   LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
-  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
-      rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
-      sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer)) {
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableProducer.generateResultTileValue(rewriter, resultNumber,
+                                               sliceOpToTile.getMixedOffsets(),
+                                               sliceOpToTile.getMixedSizes());
+  if (failed(tileAndFuseResult)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
-    return nullptr;
+    return {};
+  }
+  for (auto tiledOp : tileAndFuseResult->tiledOps) {
+    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
   }
-  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
-  Operation *fusedOp = tiledProducer->getDefiningOp();
   auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
-      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
       sliceOpToTile->getResult(0)
           .getType()
           .cast<RankedTensorType>()
           .getShape());
   assert(succeeded(maybeRankReduced) && "unexpected shape");
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
-  return fusedOp;
+  return tileAndFuseResult->tiledOps;
 }
 
 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
@@ -497,7 +498,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
 /// right before its "extract" use. The tiled op is fused under the
 /// `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+static SmallVector<Operation *>
+tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
     Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
@@ -506,7 +508,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   if (!tileableProducer) {
     diag.attachNote(producerOp->getLoc())
         << "producer is not a TileableInterface: " << *producerOp;
-    return nullptr;
+    return {};
   }
 
   // Search the first use by a "scf::ForallOp" user.
@@ -520,7 +522,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   if (!forallOp || forallOp != containingOp) {
     diag.attachNote(tileableProducer->getLoc())
         << "could not find a use by the containing op: " << *tileableProducer;
-    return nullptr;
+    return {};
   }
 
   // Search the producer slices accessed within the containing
@@ -542,7 +544,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   if (itBBArgUsers == bbArg.getUsers().end()) {
     diag.attachNote(containingOp->getLoc())
         << "could not find fusion opportunity for bbArg: " << bbArg;
-    return nullptr;
+    return {};
   }
   auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
 
@@ -562,7 +564,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
           destinationTensors))) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to get destination tensors for: " << *tileableProducer;
-    return nullptr;
+    return {};
   }
 
   IRMapping bvm;
@@ -573,21 +575,19 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
       llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
 
   // Tile the producer.
-  FailureOr<Value> tiledProducer =
+  FailureOr<TilingResult> tileAndFuseResult =
       tileableProducerClone.generateResultTileValue(
           rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
           sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer)) {
+  if (failed(tileAndFuseResult)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
-    return nullptr;
+    return {};
   }
-  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
-  Operation *fusedOp = tiledProducer->getDefiningOp();
   auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
-      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
       sliceOpToTile->getResult(0)
           .getType()
           .cast<RankedTensorType>()
@@ -601,7 +601,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
                              destinationTensors.front());
   });
 
-  return fusedOp;
+  return tileAndFuseResult->tiledOps;
 }
 
 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
@@ -714,21 +714,21 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     // cases, we can tile/clone once and reuse the value for each use.
     // Futhermore, producers should then be traversed according to a
     // topological sorting.
-    Operation *tiled =
+    SmallVector<Operation *> tiledOps =
         tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
-    if (tiled) {
+    if (!tiledOps.empty()) {
       LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
-      fusedOps.push_back(tiled);
+      fusedOps.append(tiledOps);
       continue;
     }
 
-    Operation *tiledContainingOpOperand =
+    SmallVector<Operation *> tiledContainingOpOperand =
         tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
             rewriter, diag, producerOp, containingOp);
-    if (tiledContainingOpOperand) {
+    if (!tiledContainingOpOperand.empty()) {
       LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                         << *containingOp);
-      fusedOps.push_back(tiledContainingOpOperand);
+      fusedOps.append(tiledContainingOpOperand);
       continue;
     }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index c8c9c0bd4af89..e6fce56d4140b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
   offsetsCopy[dimension] = offset;
 
   // Create the part as it it were a single tile.
-  SmallVector<Operation *> tiled =
+  FailureOr<TilingResult> tilingResult =
       op.getTiledImplementation(b, offsetsCopy, sizesCopy);
-  assert(tiled.size() == 1 && "expected a single result from tiling");
-  auto part = cast<TilingInterface>(tiled.front());
 
   // Insert the results back and populate the `results` list.
-  for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
+  for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
+    if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
                                         resultOffsets, resultSizes)))
       return nullptr;
     SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
                                             b.getIndexAttr(1));
     Value inserted = b.create<tensor::InsertSliceOp>(
-        loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
+        loc, result, resultOperands[index], resultOffsets, resultSizes,
         resultStrides);
     results.push_back(inserted);
   }
-
-  return part;
+  // TODO: this part can be generalized maybe to not expect a single op.
+  assert(tilingResult->tiledOps.size() == 1 &&
+         "expected split part to return a single tiled operation");
+  return cast<TilingInterface>(tilingResult->tiledOps[0]);
 }
 
 std::pair<TilingInterface, TilingInterface>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 62eef97a17448..1e404cabbb518 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
     }
 
     // 4. Tile the cloned op and delete the clone.
-    SmallVector<Operation *> tiledOps =
+    FailureOr<TilingResult> tilingResult =
         cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
                                                                tiledSizes);
     b.eraseOp(clonedOp);
-    assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-    tiledOp = tiledOps.front();
+    assert(tilingResult->tiledOps.size() == 1 &&
+           "expected a single produced tiled op");
+    tiledOp = tilingResult->tiledOps.front();
   }
 
   // 5. Parallel insert back into the result tensor.
@@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
     // 5. Tile the cloned op and delete the clone.
     if (tileSizes.empty()) {
-      SmallVector<Operation *> tiledOps =
+      FailureOr<TilingResult> tilingResult =
           cast<TilingInterface>(clonedOp).getTiledImplementation(
               b, tiledOffsets, tiledSizes);
-      assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-      tiledOp = tiledOps.front();
-      tilingResults = tiledOp->getResults();
+      assert(tilingResult->tiledOps.size() == 1 &&
+             "expected a single produced tiled op");
+      tiledOp = tilingResult->tiledOps.front();
+      tilingResults = tilingResult->tiledValues;
     } else {
       LinalgTilingOptions options;
       FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index cfc27ca44e421..676d6330cde3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -111,7 +111,7 @@ struct LinalgOpTilingInterface
   }
 
   // Instantiate the tiled implementation of the operation.
-  SmallVector<Operation *>
+  FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
@@ -129,7 +129,7 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return {tiledOp};
+    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
   // Return the details of the output tile generated by the tiled
@@ -160,10 +160,10 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
-                                           unsigned resultNumber,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Check that the indexing map used for the output is a projected
@@ -197,12 +197,15 @@ struct LinalgOpTilingInterface
       iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
     }
 
-    SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
-        b, iterationTileOffsets, iterationTileSizes);
-    if (tiledOp.size() != 1)
+    FailureOr<TilingResult> tilingResult =
+        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
+                                                 iterationTileSizes);
+    if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
-    return tiledOp[0]->getResult(resultNumber);
+    return TilingResult{
+        tilingResult->tiledOps,
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
   }
 
   LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 17c46182eb5d1..e001f59b21e93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
       return failure();
   }
 
-  Operation *tiledPadOp =
+  FailureOr<TilingResult> tilingResult =
       tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                                sliceOp.getMixedSizes(), zeroSliceGuard);
+  if (failed(tilingResult))
+    return failure();
   // All shapes are static and the data source is actually used. Rewrite into
   // pad(extract_slice(x)).
-  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
+  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 915e4b4ed1c56..6706f54662839 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -251,18 +251,20 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
 /// a destination passing style op.
 static SmallVector<Value>
 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
-                 Operation *tiledOp,
+                 TilingResult tilingResult,
                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                  MutableArrayRef<scf::ForOp> loops) {
   SmallVector<Value> replacements =
-      yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
+      yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
                        tileOffsetsList, tileSizesList, loops);
-  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
-    auto innerMostLoop = loops.back();
-    SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
-    updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
-                                        innerMostLoop.getRegionIterArgs());
+  for (auto tiledOp : tilingResult.tiledOps) {
+    if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
+      auto innerMostLoop = loops.back();
+      SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
+      updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
+                                          innerMostLoop.getRegionIterArgs());
+    }
   }
   return replacements;
 }
@@ -345,9 +347,9 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   if (!tilingResult.loops.empty())
     rewriter.setInsertionPoint(
         tilingResult.loops.back().getBody()->getTerminator());
-  SmallVector<Operation *> tiledImplementation =
+  FailureOr<TilingResult> tiledImplementation =
       op.getTiledImplementation(rewriter, offsets, sizes);
-  tilingResult.tiledOps.append(tiledImplementation);
+  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
   if (op->getNumResults() == 0) {
     // nothing more to do.
     return tilingResult;
@@ -356,9 +358,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
   if (tilingResult.loops.empty()) {
-    tilingResult.replacements = llvm::to_vector(
-        llvm::map_range(tiledImplementation[0]->getResults(),
-                        [](OpResult result) -> Value { return result; }));
+    tilingResult.replacements = tiledImplementation->tiledValues;
     return tilingResult;
   }
 
@@ -384,7 +384,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     return rewriter.notifyMatchFailure(op, "failed to get destinations");
 
   tilingResult.replacements = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOps.back(),
+      rewriter, destinationTensors, tiledImplementation.value(),
       resultOffsetsList, resultSizesList, tilingResult.loops);
 
   LLVM_DEBUG({
@@ -523,12 +523,13 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // 2. Generate the tiled implementation of the producer of the source
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(candidateSliceOp);
-  FailureOr<Value> fusedProducerValue =
+  FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
                                                    fusableProducer);
-  if (failed(fusedProducerValue))
+  if (failed(tileAndFuseResult))
     return std::nullopt;
-  rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value());
+  rewriter.replaceAllUsesWith(candidateSliceOp,
+                              tileAndFuseResult->tiledValues[0]);
 
   // 3. If the slice is for a destination operand, for example,
   //
@@ -592,8 +593,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
       outerMostLoop.setIterArg(iterArgNumber.value(),
                                dstOp.getTiedOpOperand(fusableProducer)->get());
     }
-    if (auto dstOp = fusedProducerValue.value()
-                         .getDefiningOp<DestinationStyleOpInterface>()) {
+    for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
+      auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
+      if (!dstOp)
+        continue;
       scf::ForOp innerMostLoop = loops.back();
       updateDestinationOperandsForTiledOp(
           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
@@ -601,7 +604,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           fusedProducerValue.value()};
+                                           tileAndFuseResult->tiledValues[0]};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 1c4db01dc8f28..0faa29ade8047 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -46,15 +46,15 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
     return loopRanges;
   }
 
-  SmallVector<Operation *>
+  FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
-    Operation *result =
+    FailureOr<TilingResult> result =
         tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
-    if (!result)
-      return {};
-    return {result};
+    if (failed(result))
+      return failure();
+    return result.value();
   }
 
   LogicalResult
@@ -117,7 +117,7 @@ struct PackOpTiling
     return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
   }
 
-  SmallVector<Operation *>
+  FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
@@ -192,7 +192,8 @@ struct PackOpTiling
     Operation *tiledPackOp = b.create<PackOp>(
         loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
 
-    return {tiledPackOp};
+    return TilingResult{{tiledPackOp},
+                        SmallVector<Value>(tiledPackOp->getResults())};
   }
 
   LogicalResult
@@ -353,7 +354,7 @@ struct UnPackOpTiling
   /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
   /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
   /// can get the actual result.
-  SmallVector<Operation *>
+  FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
@@ -412,12 +413,13 @@ struct UnPackOpTiling
         loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
 
     if (isPerfectTilingCase)
-      return {tiledUnpackOp};
+      return TilingResult{{tiledUnpackOp},
+                          SmallVector<Value>(tiledUnpackOp->getResults())};
 
-    Operation *extractSlice =
+    auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    return {tiledUnpackOp, extractSlice};
+    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
   }
 
   LogicalResult
@@ -431,26 +433,29 @@ struct UnPackOpTiling
     return success();
   }
 
-  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
-                                           unsigned resultNumber,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes) const {
-    return getTiledImplementation(op, b, offsets, sizes)
-        .back()
-        ->getResult(resultNumber);
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    FailureOr<TilingResult> tilingResult =
+        getTiledImplementation(op, b, offsets, sizes);
+    if (failed(tilingResult))
+      return failure();
+    return tilingResult.value();
   }
 };
 
 } // namespace
 
-Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
-                                    ArrayRef<OpFoldResult> offsets,
-                                    ArrayRef<OpFoldResult> sizes,
-                                    bool generateZeroSliceGuard) {
+FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
+                                                 tensor::PadOp padOp,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes,
+                                                 bool generateZeroSliceGuard) {
   // Only constant padding value supported.
   Value padValue = padOp.getConstantPaddingValue();
   if (!padValue)
-    return nullptr;
+    return failure();
 
   // Helper variables and functions for various arithmetic operations. These
   // are used extensively for computing new offset/length and padding values.
@@ -584,10 +589,9 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
       RankedTensorType::get(shape, padOp.getResultType().getElementType());
 
   // Insert cast to ensure that types match. (May be folded away.)
-  auto castResult = [&](Operation *op) -> Operation * {
-    Value val = op->getResult(0);
+  auto castResult = [&](Value val) -> Value {
     if (resultType == val.getType())
-      return op;
+      return val;
     return b.create<tensor::CastOp>(loc, resultType, val);
   };
 
@@ -601,7 +605,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
         [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
           builder.create<tensor::YieldOp>(gLoc, padValue);
         });
-    return castResult(generateOp);
+    return generateOp;
   };
 
   // Emit a SliceOp and a PadOp. Should not be used in cases where
@@ -617,30 +621,38 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
     padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
 
     // Cast result and return.
-    return castResult(newPadOp);
+    return newPadOp;
   };
 
   // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
   // the original data source x is not used.
-  if (hasZeroLen)
-    return createGenerateOp();
+  if (hasZeroLen) {
+    Operation *generateOp = createGenerateOp();
+    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+  }
 
   // If there are dynamic dimensions: Generate an scf.if check to avoid
   // creating SliceOps with result dimensions of size 0 at runtime.
   if (generateZeroSliceGuard && dynHasZeroLenCond) {
+    Operation *thenOp;
+    Operation *elseOp;
     auto result = b.create<scf::IfOp>(
         loc, dynHasZeroLenCond,
         /*thenBuilder=*/
         [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
+          thenOp = createGenerateOp();
+          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
         },
         /*elseBuilder=*/
         [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
+          elseOp = createPadOfExtractSlice();
+          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
         });
-    return result;
+    return TilingResult{{result}, SmallVector<Value>(result->getResults())};
   }
-  return createPadOfExtractSlice();
+
+  Operation *newPadOp = createPadOfExtractSlice();
+  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
 }
 
 void mlir::tensor::registerTilingInterfaceExternalModels(

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 65176ed7b9e74..40d79c2053817 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -20,7 +20,7 @@
 
 using namespace mlir;
 
-FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
+FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
   auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
   if (!producerOp)
@@ -32,7 +32,7 @@ FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
       }))
     return failure();
 
-  FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
+  FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
       builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
       sliceOp.getMixedSizes());
   if (failed(tiledResult))


        


More information about the Mlir-commits mailing list