[Mlir-commits] [mlir] 94f2a6d - [mlir][TilingInterface] NFC: Consolidate yield handling.

Mahesh Ravishankar llvmlistbot at llvm.org
Sun Jan 15 21:19:00 PST 2023


Author: Mahesh Ravishankar
Date: 2023-01-16T05:03:41Z
New Revision: 94f2a6ddde5cc8940bd26cec68899d84934c21df

URL: https://github.com/llvm/llvm-project/commit/94f2a6ddde5cc8940bd26cec68899d84934c21df
DIFF: https://github.com/llvm/llvm-project/commit/94f2a6ddde5cc8940bd26cec68899d84934c21df.diff

LOG: [mlir][TilingInterface] NFC: Consolidate yield handling.

Add a new utility method that yields the tiled values and also updates
the tiled op's destination operands, preserving destination-passing style.

Differential Revision: https://reviews.llvm.org/D139392

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 02323c584a84f..52cd7609d55e1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -173,7 +173,7 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
 /// }
 /// ```
 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
-static FailureOr<SmallVector<Value>>
+static SmallVector<Value>
 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                  ValueRange yieldedValues,
                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
@@ -245,6 +245,27 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
   }
 }
 
+/// Yields the results of `tiledOp` through the `loops` nest and returns the
+/// replacement values. If `tiledOp` is a destination-style op, its DPS init
+/// operands are also rewired to the innermost loop's region iter_args.
+static SmallVector<Value>
+yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
+                 Operation *tiledOp,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops) {
+  SmallVector<Value> replacements =
+      yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
+                       tileOffsetsList, tileSizesList, loops);
+  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
+    auto innerMostLoop = loops.back(); // NOTE(review): assumes non-empty loops
+    SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
+    updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+  return replacements;
+}
+
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
@@ -258,12 +279,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
         op, "missing tile size computation function");
   }
 
-  // Get destination tensors.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors)))
-    return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
@@ -362,24 +377,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     }
   }
 
-  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
-      resultOffsetsList, resultSizesList, tilingResult.loops);
-  if (failed(replacementOr))
-    return rewriter.notifyMatchFailure(op, "failed to yield replacement");
-
-  if (auto dstOp =
-          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
-    auto innerMostLoop = tilingResult.loops.back();
-    SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
-    assert(destinationTensors.size() ==
-               innerMostLoop.getRegionIterArgs().size() &&
-           "unexpected number of outputs");
-    updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
-                                        innerMostLoop.getRegionIterArgs());
-  }
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors)))
+    return rewriter.notifyMatchFailure(op, "failed to get destinations");
 
-  tilingResult.replacements = *replacementOr;
+  tilingResult.replacements = yieldTiledValues(
+      rewriter, destinationTensors, tilingResult.tiledOps.back(),
+      resultOffsetsList, resultSizesList, tilingResult.loops);
 
   LLVM_DEBUG({
     if (!tilingResult.loops.empty()) {
@@ -449,11 +454,9 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
     resultSizesList.push_back(
         b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+  SmallVector<Value> replacements = yieldTiledValues(
       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
       resultSizesList, loops);
-  if (failed(replacementOr))
-    return b.notifyMatchFailure(op, "failed to yield replacement");
 
   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
   auto innerMostLoop = loops.back();
@@ -466,7 +469,7 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
 
   // 4. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
-  Operation *mergeOp = op.mergeReductions(b, loc, *replacementOr, reductionDim);
+  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
   b.replaceOp(op, mergeOp->getResults());
 
   SCFReductionTilingResult results;


        


More information about the Mlir-commits mailing list