[Mlir-commits] [mlir] aa2210a - [linalg] Expose `rewriteAsPaddedOp` function.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Aug 6 03:08:29 PDT 2021
Author: Alexander Belyaev
Date: 2021-08-06T12:08:12+02:00
New Revision: aa2210a830699fbf6e218789ac3da7abffee0b0c
URL: https://github.com/llvm/llvm-project/commit/aa2210a830699fbf6e218789ac3da7abffee0b0c
DIFF: https://github.com/llvm/llvm-project/commit/aa2210a830699fbf6e218789ac3da7abffee0b0c.diff
LOG: [linalg] Expose `rewriteAsPaddedOp` function.
Differential Revision: https://reviews.llvm.org/D107629
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a2ce99bc97b0e..87908ac88231f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -886,6 +886,13 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
PatternRewriter &rewriter) const override;
};
+/// Pad the operands of `opToPad` to static bounding boxes. If successful,
+/// `paddedOp` is set to the padded clone (unchanged if no padding needed).
+LogicalResult
+rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+ const PaddingValueComputationFunction &paddingFunc,
+ LinalgOp &paddedOp);
+
using OptimizeCopyFn =
std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8d3ee8fe5566d..7fc1c5f47043f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -126,7 +126,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
- const LinalgTilingOptions &options, Value &result) {
+ const PaddingValueComputationFunction &paddingFunc, Value &result) {
// Already static shape, no need to pad.
if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
return success();
@@ -148,7 +148,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
opToPad, "No constant bounding box can be found for padding");
staticSizes.push_back(indexAttr.getInt());
}
- Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
+ Value pad = paddingFunc(rewriter, *opOperand);
auto staticTensorType = RankedTensorType::get(
staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
@@ -156,13 +156,10 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
return success();
}
-// Try to create a static bounding box around each operand of `res.op`.
-// If successful, `res.op` is rewritten in static form with padded operands.
-// `res.op` is updated to the cloned static form of the op on success.
-static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
- TiledLinalgOp &res,
- const LinalgTilingOptions &options) {
- LinalgOp opToPad = res.op;
+LogicalResult
+linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+ const PaddingValueComputationFunction &paddingFunc,
+ LinalgOp &paddedOp) {
Location loc = opToPad->getLoc();
// If the op is fully static, it does not need padding.
@@ -183,7 +180,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
// If padding was requested but the shape cannot be bounded statically then
// the pattern fails to apply.
if (failed(padOperandToSmallestStaticBoundingBox(
- rewriter, opToPad, opOperand, options, paddedOperand)))
+ rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
return failure();
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
}
@@ -191,8 +188,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
// Clone `opToPad` to operate on the statically padded shapes.
auto resultTensorTypes =
ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
- linalg::LinalgOp paddedOp =
- opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
+ paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
// Recover the slice out of the new static results. This keeps the original
// linalg op around because it uses the dims of the original results.
@@ -218,8 +214,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
return !newUsersOfOpToPad.contains(opOp.getOwner());
});
-
- res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
return success();
}
@@ -265,15 +259,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
!linalgOp.hasTensorSemantics())
return success();
- // Try to pad on the fly by rewriting res->op as a padded op.
- if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
- // Set so RAII guard does not propagate TiledLinalgOp to `result`.
- return failure();
+ // Try to pad on the fly by rewriting res->op as a padded op. If successful,
+ // `res.op` is rewritten in static form with padded operands.
+ LinalgOp paddedOp;
+ if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
+ options.paddingValueComputationFunction,
+ paddedOp))) {
+ res->op = paddedOp;
+ // Do not perform replacement of `linalgOp`, let the derived patterns
+ // do this as they see fit, from the resulting TiledLinalgOp.
+ return success();
}
-
- // Do not perform replacement of `linalgOp`, let the derived patterns
- // do this as they see fit, from the resulting TiledLinalgOp.
- return success();
+ // Set so RAII guard does not propagate TiledLinalgOp to `result`.
+ return failure();
}
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
More information about the Mlir-commits
mailing list