[Mlir-commits] [mlir] 94f2a6d - [mlir][TilingInterface] NFC: Consolidate yield handling.
Mahesh Ravishankar
llvmlistbot at llvm.org
Sun Jan 15 21:19:00 PST 2023
Author: Mahesh Ravishankar
Date: 2023-01-16T05:03:41Z
New Revision: 94f2a6ddde5cc8940bd26cec68899d84934c21df
URL: https://github.com/llvm/llvm-project/commit/94f2a6ddde5cc8940bd26cec68899d84934c21df
DIFF: https://github.com/llvm/llvm-project/commit/94f2a6ddde5cc8940bd26cec68899d84934c21df.diff
LOG: [mlir][TilingInterface] NFC: Consolidate yield handling.
Add a new utility method that yields the tiled values and also
preserves destination-passing style.
Differential Revision: https://reviews.llvm.org/D139392
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 02323c584a84f..52cd7609d55e1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -173,7 +173,7 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
/// }
/// ```
/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
-static FailureOr<SmallVector<Value>>
+static SmallVector<Value>
yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
ValueRange yieldedValues,
ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
@@ -245,6 +245,27 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
}
}
+/// Helper method to yield the values of the tiled op and, if the tiled op
+/// implements `DestinationStyleOpInterface`, to update its destination
+/// (DPS init) operands using the region iter_args of the innermost loop.
+///
+/// The results of `tiledOp` are forwarded, together with `initValues`,
+/// `tileOffsetsList`, `tileSizesList` and `loops`, to the `ValueRange`
+/// overload of `yieldTiledValues`; its return value is returned unchanged
+/// as the replacement values.
+/// NOTE(review): `loops` is a `MutableArrayRef`, presumably because the
+/// callee replaces the loops in place with new ones carrying the extra
+/// yields. `loops.back()` is read only after that call -- confirm.
+static SmallVector<Value>
+yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
+                 Operation *tiledOp,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops) {
+  SmallVector<Value> replacements =
+      yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
+                       tileOffsetsList, tileSizesList, loops);
+  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
+    // `back()` on an empty array is undefined; the destination update needs
+    // an innermost loop to take the iter_args from.
+    assert(!loops.empty() && "expected at least one loop");
+    scf::ForOp innerMostLoop = loops.back();
+    SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
+    // Each DPS init operand of the tiled op is paired with one region
+    // iter_arg of the innermost loop.
+    assert(tiledOpDestinationTensors.size() ==
+               innerMostLoop.getRegionIterArgs().size() &&
+           "unexpected number of outputs");
+    updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+  return replacements;
+}
+
/// Implementation of tiling transformation of `op` that implements the
/// `TilingInterface` using `scf.for` to iterate over the tiles.
FailureOr<scf::SCFTilingResult>
@@ -258,12 +279,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
op, "missing tile size computation function");
}
- // Get destination tensors.
- SmallVector<Value> destinationTensors;
- if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
- destinationTensors)))
- return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
size_t numLoops = iterationDomain.size();
@@ -362,24 +377,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
}
- FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
- rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
- resultOffsetsList, resultSizesList, tilingResult.loops);
- if (failed(replacementOr))
- return rewriter.notifyMatchFailure(op, "failed to yield replacement");
-
- if (auto dstOp =
- dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
- auto innerMostLoop = tilingResult.loops.back();
- SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
- assert(destinationTensors.size() ==
- innerMostLoop.getRegionIterArgs().size() &&
- "unexpected number of outputs");
- updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
- innerMostLoop.getRegionIterArgs());
- }
+ SmallVector<Value> destinationTensors;
+ if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+ destinationTensors)))
+ return rewriter.notifyMatchFailure(op, "failed to get destinations");
- tilingResult.replacements = *replacementOr;
+ tilingResult.replacements = yieldTiledValues(
+ rewriter, destinationTensors, tilingResult.tiledOps.back(),
+ resultOffsetsList, resultSizesList, tilingResult.loops);
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
@@ -449,11 +454,9 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
resultSizesList.push_back(
b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+ SmallVector<Value> replacements = yieldTiledValues(
b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
resultSizesList, loops);
- if (failed(replacementOr))
- return b.notifyMatchFailure(op, "failed to yield replacement");
auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
auto innerMostLoop = loops.back();
@@ -466,7 +469,7 @@ mlir::scf::tileReductionUsingScf(PatternRewriter &b,
// 4. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
- Operation *mergeOp = op.mergeReductions(b, loc, *replacementOr, reductionDim);
+ Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
b.replaceOp(op, mergeOp->getResults());
SCFReductionTilingResult results;
More information about the Mlir-commits
mailing list