[Mlir-commits] [mlir] 87656a3 - [mlir][linalg] Fold TensorCast into PadTensorOp.
Tobias Gysi
llvmlistbot at llvm.org
Mon Jul 19 08:58:10 PDT 2021
Author: Tobias Gysi
Date: 2021-07-19T15:57:38Z
New Revision: 87656a3134c7c03565efca85352a58541ce68789
URL: https://github.com/llvm/llvm-project/commit/87656a3134c7c03565efca85352a58541ce68789
DIFF: https://github.com/llvm/llvm-project/commit/87656a3134c7c03565efca85352a58541ce68789.diff
LOG: [mlir][linalg] Fold TensorCast into PadTensorOp.
Add pattern to fold a TensorCast into a PadTensorOp if the cast removes static size information.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D106278
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93d330d8d846a..16ff0711f4744 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1063,11 +1063,29 @@ struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
return success();
}
};
+
+// Fold a tensor.cast that produces the source of a PadTensorOp into the pad,
+// provided tensor::canFoldIntoConsumerOp accepts the cast (i.e. the cast only
+// erases static size information). The pad then consumes the more static,
+// pre-cast value directly.
+//
+// Bypassing the cast can make the type inferred from the new source and the
+// static padding amounts more static than the declared result type, and the
+// PadTensorOp verifier rejects a dynamic result dimension whose inferred size
+// is static. Therefore:
+//  - if the declared result type stays consistent, only the source operand is
+//    swapped in place;
+//  - if it would have to become more static, a new pad with the refined result
+//    type is created and cast back to the original result type;
+//  - if an inferred static size contradicts a static result size, the pattern
+//    does not apply.
+struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // A null castOp (source not produced by a tensor.cast) is rejected by
+    // canFoldIntoConsumerOp as well.
+    auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    auto newSourceType = castOp.source().getType().cast<RankedTensorType>();
+    auto resultType = padTensorOp.getType().cast<RankedTensorType>();
+    RankedTensorType inferredType = PadTensorOp::inferResultType(
+        newSourceType, extractFromI64ArrayAttr(padTensorOp.static_low()),
+        extractFromI64ArrayAttr(padTensorOp.static_high()));
+
+    // Merge the declared and inferred shapes: static result sizes are kept,
+    // dynamic ones take the inferred size, and a static/static mismatch aborts.
+    SmallVector<int64_t, 4> refinedShape;
+    refinedShape.reserve(resultType.getRank());
+    for (auto dims : llvm::zip(resultType.getShape(), inferredType.getShape())) {
+      int64_t resultDim = std::get<0>(dims);
+      int64_t inferredDim = std::get<1>(dims);
+      if (ShapedType::isDynamic(resultDim)) {
+        refinedShape.push_back(inferredDim);
+        continue;
+      }
+      if (!ShapedType::isDynamic(inferredDim) && inferredDim != resultDim)
+        return failure();
+      refinedShape.push_back(resultDim);
+    }
+
+    // Result type unchanged: just bypass the cast.
+    if (ArrayRef<int64_t>(refinedShape) == resultType.getShape()) {
+      rewriter.updateRootInPlace(padTensorOp, [&]() {
+        padTensorOp.sourceMutable().assign(castOp.source());
+      });
+      return success();
+    }
+
+    // Result type becomes more static: pad with the refined type, then cast
+    // back so existing users still see the original result type.
+    auto refinedType =
+        RankedTensorType::get(refinedShape, resultType.getElementType());
+    auto newPadOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), refinedType, castOp.source(), padTensorOp.low(),
+        padTensorOp.high(), padTensorOp.static_low(),
+        padTensorOp.static_high());
+    rewriter.cloneRegionBefore(padTensorOp.region(), newPadOp.region(),
+                               newPadOp.region().end());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, resultType,
+                                                newPadOp.result());
+    return success();
+  }
+};
} // namespace
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand>(context);
+  // Patterns that were already registered for PadTensorOp.
+  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand>(context);
+  // Folds a tensor.cast that produces the pad source into the pad itself.
+  results.add<FoldSourceTensorCast>(context);
}
/// Return the padding value of the PadTensorOp if it is constant. In this context,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c453255d39485..b24876cfe255d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -772,6 +772,22 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
// -----
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
+// CHECK-NOT: tensor.cast
+// CHECK: %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]]
+// CHECK: : tensor<4x?xf32> to tensor<4x4xf32>
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[RESULT]]
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+  %cst = constant 0.0 : f32
+  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[0, 1] {
+  ^bb0(%arg1: index, %arg2: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @pad_static_zero_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
// CHECK-NOT: linalg.pad_tensor
More information about the Mlir-commits
mailing list