[Mlir-commits] [mlir] 06a0385 - [mlir][linalg] Fold tensor.pad(linalg.fill) with the same value

Lei Zhang llvmlistbot at llvm.org
Thu Feb 10 05:39:47 PST 2022


Author: Lei Zhang
Date: 2022-02-10T08:39:35-05:00
New Revision: 06a03851429d387af5a983602a79b8f7757f6b86

URL: https://github.com/llvm/llvm-project/commit/06a03851429d387af5a983602a79b8f7757f6b86
DIFF: https://github.com/llvm/llvm-project/commit/06a03851429d387af5a983602a79b8f7757f6b86.diff

LOG: [mlir][linalg] Fold tensor.pad(linalg.fill) with the same value

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D119160

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 133ff048a44a..4868fdb99341 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -441,12 +441,52 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
+/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
+/// filling value are the same.
+struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = padOp.source().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // We can only fold if the padding value is the same as the original
+    // filling value.
+    Value padValue = padOp.getConstantPaddingValue();
+    if (!padValue || fillOp.value() != padValue)
+      return failure();
+
+    ReifiedRankedShapedTypeDims reifiedShape;
+    ReifyRankedShapedTypeOpInterface interface =
+        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
+    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
+      return rewriter.notifyMatchFailure(
+          padOp, "failed to reify tensor.pad op result shape");
+
+    auto oldResultType = padOp.getResultType();
+    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
+                                        ShapedType::kDynamicSize);
+    auto newInitOp = rewriter.create<InitTensorOp>(
+        padOp.getLoc(), reifiedShape.front(), staticShape,
+        oldResultType.getElementType());
+    auto newFillOp =
+        rewriter.create<FillOp>(fillOp.getLoc(), padValue, newInitOp);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
+                                                newFillOp.result());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-              FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
+  results
+      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+           FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 48f70c1404ad..cbc8e4a50de5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -585,3 +585,68 @@ func @fold_self_copy(%0 : memref<4x16xf32>) {
     }
   return 
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_static_pad_fill
+//       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
+//       CHECK:   %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//       CHECK:   return %[[FILL]]
+func @fold_static_pad_fill() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f0 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
+// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
+// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
+
+//      CHECK: func @fold_dynamic_pad_fill
+// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
+
+//  CHECK-DAG:   %[[I1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
+//      CHECK:   %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
+//      CHECK:   %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
+//      CHECK:   %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
+//      CHECK:   %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
+//      CHECK:   %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//      CHECK:   return %[[FILL]]
+func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
+  %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %f0 : f32
+  } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
+  return %pad : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
+func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  // CHECK: tensor.pad
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f1 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}


        


More information about the Mlir-commits mailing list