[Mlir-commits] [mlir] b9d7ffd - Folds linalg.pad_tensor with zero padding
Ahmed Taei
llvmlistbot at llvm.org
Wed Jun 9 15:43:40 PDT 2021
Author: Ahmed Taei
Date: 2021-06-09T15:39:40-07:00
New Revision: b9d7ffd9cf5f9caefb9796468bf4cbeec709b320
URL: https://github.com/llvm/llvm-project/commit/b9d7ffd9cf5f9caefb9796468bf4cbeec709b320
DIFF: https://github.com/llvm/llvm-project/commit/b9d7ffd9cf5f9caefb9796468bf4cbeec709b320.diff
LOG: Folds linalg.pad_tensor with zero padding
Differential Revision: https://reviews.llvm.org/D103984
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 8b0766b2f703..51e1ab401c9e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -291,6 +291,8 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
"ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
+
+ let hasCanonicalizer = 1;
}
def Linalg_RangeOp :
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 47c6bc70339f..8830f57f2a1c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1119,6 +1119,28 @@ LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
return success();
}
+namespace {
+// Folds a linalg.pad_tensor whose low/high pads are static zeros into a cast.
+struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
+ using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+ // Zero padding leaves the data unchanged; tensor.cast reconciles the types.
+ LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+ PatternRewriter &rewriter) const override {
+ if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
+ return failure();
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
+ return success();
+ }
+};
+
+} // namespace
+// Registers PadTensorOp's canonicalization patterns (hasCanonicalizer = 1).
+void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldStaticZeroPadding>(context);
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b3915477ce47..c51cbdbb3568 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1132,3 +1132,19 @@ func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
// CHECK-NEXT: %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
// CHECK-NEXT: %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
// CHECK-NEXT: linalg.yield %[[SUM2]] : index
+
+// -----
+// Zero padding rewrites pad_tensor to a tensor.cast, which folds with %0.
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+ %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
+ %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %c0] {
+ ^bb0(%arg1: index, %arg2: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+// CHECK-LABEL: @tensor_pad_cast_fold
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
+// CHECK: return %[[ARG0]]
More information about the Mlir-commits
mailing list