[Mlir-commits] [mlir] b01d223 - [mlir][Linalg] Use reify for padded op shape derivation.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Sep 13 04:55:04 PDT 2021
Author: Nicolas Vasilache
Date: 2021-09-13T11:54:59Z
New Revision: b01d223faf8ac4d62baea8a0e1d7b6cab7938118
URL: https://github.com/llvm/llvm-project/commit/b01d223faf8ac4d62baea8a0e1d7b6cab7938118
DIFF: https://github.com/llvm/llvm-project/commit/b01d223faf8ac4d62baea8a0e1d7b6cab7938118.diff
LOG: [mlir][Linalg] Use reify for padded op shape derivation.
Previously, we would insert a DimOp and rely on later canonicalizations.
Unfortunately, reifyShape-style rewrites are no longer canonicalizations.
This introduces undesirable pass dependencies.
Instead, immediately reify the result shape and avoid the DimOp altogether.
This is akin to a local folding, which avoids introducing more reliance on `-resolve-shaped-type-result-dims` (similar to composing `affine.apply` ops at construction time to avoid chains of length > 1).
It does not completely remove the reliance on the pass, because the process is merely local: running the pass may still be necessary for global effects. Indeed, one of the tests still requires the pass.
Differential Revision: https://reviews.llvm.org/D109571
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 52a652a697b00..8541b2aec6a47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -185,6 +185,13 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
}
+ SmallVector<SmallVector<Value>> reifiedResultShapes;
+ if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
+ .reifyResultShapes(rewriter, reifiedResultShapes)))
+ return failure();
+ assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ "expected same number of results");
+
// Clone `opToPad` to operate on the statically padded shapes.
auto resultTensorTypes =
ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
@@ -192,28 +199,21 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
// Recover the slice out of the new static results. This keeps the original
// linalg op around because it uses the dims of the original results.
- // This later folds away.
SmallVector<Value> paddedSubviewResults;
paddedSubviewResults.reserve(opToPad->getNumResults());
- SetVector<Operation *> newUsersOfOpToPad;
- for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
- auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
+ for (auto en : llvm::enumerate(paddedOp->getResults())) {
+ Value paddedResult = en.value();
+ int64_t resultNumber = en.index();
+ int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- auto sizes = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
- auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
- newUsersOfOpToPad.insert(dimOp);
- return dimOp.getResult();
- }));
+ SmallVector<OpFoldResult> sizes;
+ for (Value v : reifiedResultShapes[resultNumber])
+ sizes.push_back(v);
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
- loc, std::get<1>(it), offsets, sizes, strides));
+ loc, paddedResult, offsets, sizes, strides));
}
- // Replace the transient `opToPad` locally, except for uses that we just
- // created for the purpose of extracting the dims.
- rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
- return !newUsersOfOpToPad.contains(opOp.getOwner());
- });
+ rewriter.replaceOp(opToPad, paddedSubviewResults);
return success();
}
@@ -244,14 +244,16 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
return failure();
// Setup RAII guard to return properly.
+ LinalgOp paddedOp;
LinalgOp tiledOp = res->op;
auto guard = llvm::make_scope_exit([&]() {
// Return relevant information to derived pattern.
result = *res;
- // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
- filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
- if (tiledOp != res->op)
- filter.replaceLinalgTransformationFilter(rewriter, res->op);
+ // Update filter.
+ if (paddedOp)
+ filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
+ else
+ filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
});
// Consider padding on the fly only if the op has tensor semantics.
@@ -261,7 +263,6 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
// Try to pad on the fly by rewriting res->op as a padded op. If successful,
// `res.op` is rewritten in static form with padded operands.
- LinalgOp paddedOp;
if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
options.paddingValueComputationFunction,
paddedOp))) {
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index e2a22fa104b22..ba8b4e63cd10e 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
// RUN: FileCheck %s -check-prefix=TILE2
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
// RUN: FileCheck %s -check-prefix=TILE1
More information about the Mlir-commits
mailing list