[Mlir-commits] [mlir] b9d7ffd - Folds linalg.pad_tensor with zero padding

Ahmed Taei llvmlistbot at llvm.org
Wed Jun 9 15:43:40 PDT 2021


Author: Ahmed Taei
Date: 2021-06-09T15:39:40-07:00
New Revision: b9d7ffd9cf5f9caefb9796468bf4cbeec709b320

URL: https://github.com/llvm/llvm-project/commit/b9d7ffd9cf5f9caefb9796468bf4cbeec709b320
DIFF: https://github.com/llvm/llvm-project/commit/b9d7ffd9cf5f9caefb9796468bf4cbeec709b320.diff

LOG: Folds linalg.pad_tensor with zero padding

Differential Revision: https://reviews.llvm.org/D103984

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 8b0766b2f703..51e1ab401c9e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -291,6 +291,8 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 def Linalg_RangeOp :

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 47c6bc70339f..8830f57f2a1c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1119,6 +1119,28 @@ LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
   return success();
 }
 
+namespace {
+/// Folds a pad_tensor whose low and high pads are all static zeros to a cast.
+struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (!(padOp.hasZeroLowPad() && padOp.hasZeroHighPad()))
+      return failure();
+    // Cast since the result may be more static than the source (?x? -> 4x4).
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, padOp.result().getType(),
+                                                padOp.source());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers FoldStaticZeroPadding as a pad_tensor canonicalization.
+void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldStaticZeroPadding>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b3915477ce47..c51cbdbb3568 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1132,3 +1132,19 @@ func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
 // CHECK-NEXT:     %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
 // CHECK-NEXT:     %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
 // CHECK-NEXT:     linalg.yield %[[SUM2]] : index
+
+// -----
+
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %c0]  {
+    ^bb0(%arg1: index, %arg2: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+// CHECK-LABEL: @tensor_pad_cast_fold
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
+// CHECK-NEXT: return %[[ARG0]]


        


More information about the Mlir-commits mailing list