[Mlir-commits] [mlir] 06a0385 - [mlir][linalg] Fold tensor.pad(linalg.fill) with the same value
Lei Zhang
llvmlistbot at llvm.org
Thu Feb 10 05:39:47 PST 2022
Author: Lei Zhang
Date: 2022-02-10T08:39:35-05:00
New Revision: 06a03851429d387af5a983602a79b8f7757f6b86
URL: https://github.com/llvm/llvm-project/commit/06a03851429d387af5a983602a79b8f7757f6b86
DIFF: https://github.com/llvm/llvm-project/commit/06a03851429d387af5a983602a79b8f7757f6b86.diff
LOG: [mlir][linalg] Fold tensor.pad(linalg.fill) with the same value
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D119160
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 133ff048a44a..4868fdb99341 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -441,12 +441,52 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
}
};
+/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
+/// filling value are the same.
+///
+/// Matches only when the tensor.pad source is produced directly by a
+/// linalg.fill and the pad op's constant padding value is the very same SSA
+/// value as the fill value (mlir::Value compares by identity, not by attribute
+/// contents, so two distinct but equal constants do not match). The tensor.pad
+/// is replaced by a new linalg.fill of that value into a fresh
+/// linalg.init_tensor sized to the pad's reified result shape, followed by a
+/// tensor.cast back to the pad's result type. The original linalg.fill is not
+/// erased by this pattern.
+struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ auto fillOp = padOp.source().getDefiningOp<linalg::FillOp>();
+ if (!fillOp)
+ return failure();
+
+ // We can only fold if the padding value is the same as the original
+ // filling value.
+ // NOTE(review): getConstantPaddingValue() may return a null Value (hence
+ // the !padValue guard); presumably this is the case when the pad region
+ // does not yield a single value defined outside of it -- confirm against
+ // tensor::PadOp.
+ Value padValue = padOp.getConstantPaddingValue();
+ if (!padValue || fillOp.value() != padValue)
+ return failure();
+
+ // Materialize the pad op's result sizes; reifiedShape.front() holds the
+ // sizes of its (only) result and is used below as the init_tensor sizes.
+ ReifiedRankedShapedTypeDims reifiedShape;
+ ReifyRankedShapedTypeOpInterface interface =
+ cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
+ if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
+ return rewriter.notifyMatchFailure(
+ padOp, "failed to reify tensor.pad op result shape");
+
+ // Declare every dimension dynamic so all sizes are taken from
+ // reifiedShape. The new init_tensor/fill type can therefore be less static
+ // than the pad result type; the tensor.cast below restores the exact type.
+ auto oldResultType = padOp.getResultType();
+ SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
+ ShapedType::kDynamicSize);
+ auto newInitOp = rewriter.create<InitTensorOp>(
+ padOp.getLoc(), reifiedShape.front(), staticShape,
+ oldResultType.getElementType());
+ // Refill the larger tensor with the shared value, reusing the original
+ // fill's location.
+ auto newFillOp =
+ rewriter.create<FillOp>(fillOp.getLoc(), padValue, newInitOp);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
+ newFillOp.result());
+
+ return success();
+ }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
- FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
+ // FoldFillWithPad is rooted at tensor.pad rather than linalg.fill; it is
+ // registered here, next to the other fill-folding patterns, because it only
+ // matches when the tensor.pad source is produced by a linalg.fill.
+ results
+ .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 48f70c1404ad..cbc8e4a50de5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -585,3 +585,68 @@ func @fold_self_copy(%0 : memref<4x16xf32>) {
}
return
}
+
+// -----
+
+// Static shapes: padding a 400x273 tensor filled with 0.0 by low[4, 1] and
+// high[8, 2] with the same 0.0 constant should become a single 0.0 fill of a
+// 412x276 init_tensor (400 + 4 + 8 = 412, 273 + 1 + 2 = 276); the returned
+// value is that fill directly, with no tensor.pad or tensor.cast in between.
+// CHECK-LABEL: func @fold_static_pad_fill
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+// CHECK: return %[[FILL]]
+func @fold_static_pad_fill() -> tensor<412x276xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+ %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %f0 : f32
+ } : tensor<400x273xf32> to tensor<412x276xf32>
+ return %pad : tensor<412x276xf32>
+}
+
+// -----
+
+// Dynamic shapes: each result size is low + source size + high, reified as
+// affine.apply ops: dim 0 = %low0 + 8 + 1 (MAP0), dim 1 = 8 + dim(fill, 1) + 2
+// (MAP1), dim 2 = 7 + 16 + %high2 (MAP2), dim 3 = %low3 + 32 + %high3 (MAP3).
+// The new init_tensor is tensor<?x?x?x?xf32>, which already equals the pad
+// result type, so the new fill is expected to be returned directly.
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
+// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
+// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
+
+// CHECK: func @fold_dynamic_pad_fill
+// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
+
+// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
+// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
+// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
+// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
+// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+// CHECK: return %[[FILL]]
+func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
+ %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %f0 : f32
+ } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
+ return %pad : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// Negative case: the fill value (0.0) and the padding value (1.0) are
+// different SSA values, so the pattern must not fire and tensor.pad remains.
+// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
+func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+ // CHECK: tensor.pad
+ %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %f1 : f32
+ } : tensor<400x273xf32> to tensor<412x276xf32>
+ return %pad : tensor<412x276xf32>
+}
More information about the Mlir-commits mailing list