[Mlir-commits] [mlir] [mlir][tosa] Fold PadOp to tensor operations (PR #132700)
Georgios Pinitas
llvmlistbot at llvm.org
Tue Apr 8 06:12:59 PDT 2025
https://github.com/GeorgeARM updated https://github.com/llvm/llvm-project/pull/132700
From ffc7db962b7780d2656a7deefbec72319367dd52 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Sat, 22 Mar 2025 05:41:48 +0000
Subject: [PATCH] [mlir][tosa] Fold PadOp to tensor operations
Add canonicalization patterns that fold an explicit tosa.pad operation into
the implicit padding attribute of the consuming tensor operation.
Folding into the following operations is supported, as illustrated below:
- Conv2d
- DepthwiseConv2d
- AvgPool2d
- MaxPool2d
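For illustration, a condensed before/after sketch (based on the
@pad_wh_conv2d_fold test added in this patch; types elided for brevity):

  %padded = tosa.pad %input, %pad_shape, %pad_const ...
  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp
      {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>,
       dilation = array<i64: 1, 1>} ...

canonicalizes (e.g. under mlir-opt --canonicalize) to:

  %conv = tosa.conv2d %input, %weight, %bias, %input_zp, %weight_zp
      {acc_type = f32, pad = array<i64: 2, 2, 3, 3>, stride = array<i64: 1, 1>,
       dilation = array<i64: 1, 1>} ...

The pad operator's H/W padding of [1, 1] and [2, 2] is added to the conv2d's
existing pad of [1, 1, 1, 1], giving the folded pad of [2, 2, 3, 3].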
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
Co-authored-by: Rob-Hughes-Arm <robert.hughes at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 5 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 302 ++++++++++++++++--
mlir/test/Dialect/Tosa/canonicalize.mlir | 152 +++++++++
3 files changed, 424 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 741de84cc5840..543180e68190f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
}];
let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
}];
let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 09d2c5d35263c..6fab358c2777c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -39,6 +39,273 @@ using namespace mlir::tosa;
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the padding constant and the zero point of the tensor operation are aligned.
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+ // Check that padConst is a constant value and a scalar tensor
+ DenseElementsAttr padConstAttr;
+ if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+ (padConstAttr.size() != 1)) {
+ return false;
+ }
+
+ // Check that floating point pad is zero
+ if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+ float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+ return padConstVal == 0.0f;
+ }
+
+ // Check that the zp and padConst align for the integer (quantized) case
+ if (auto padConstIntAttr =
+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+ DenseIntElementsAttr zpAttr;
+ // Check that zp is a constant value and a scalar tensor
+ if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+ return false;
+ }
+
+ // Check equality
+ int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+ int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+ return zpVal == padConstVal;
+ }
+
+ // Bail-out on unsupported type
+ return false;
+}
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+ using OpTy = tosa::AvgPool2dOp;
+ static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+ const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+ if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+ newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+ return false;
+ return true;
+ }
+ static bool checkPadConstCompliance(OpTy op, Value padConst) {
+ return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+ }
+ static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+ Value padInput, ArrayRef<int64_t> newPad) {
+ rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+ op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+ op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+ op.getAccType());
+ }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+ using OpTy = tosa::MaxPool2dOp;
+ static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+ const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+ if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+ newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+ return false;
+ return true;
+ }
+ static bool checkPadConstCompliance(OpTy, Value padConst) {
+ // Check that padConst is a constant value and a scalar tensor
+ DenseElementsAttr padConstAttr;
+ if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+ padConstAttr.size() != 1) {
+ return false;
+ }
+
+ // The pad constant must be the minimum value for the merge to be valid
+ if (auto padConstFpAttr =
+ mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+ float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+ return padConstVal == std::numeric_limits<float>::lowest();
+ } else if (auto padConstIntAttr =
+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+ int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+ return padConstVal == std::numeric_limits<int32_t>::lowest();
+ }
+
+ // Bail-out on unsupported type
+ return false;
+ }
+ static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+ Value padInput, ArrayRef<int64_t> newPad) {
+ rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+ op, op.getType(), padInput, op.getKernel(), op.getStride(),
+ rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+ }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+ static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+ return true;
+ }
+ static bool checkPadConstCompliance(OpTy op, Value padConst) {
+ return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+ }
+ static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+ Value padInput, ArrayRef<int64_t> newPad) {
+ rewriter.replaceOpWithNewOp<OpTy>(
+ op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+ op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+ op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+ }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator into a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly into the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy tensorOp,
+ PatternRewriter &rewriter) const override {
+ // Check producer is a tosa::PadOp
+ auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+ if (!padOp)
+ return rewriter.notifyMatchFailure(tensorOp,
+ "Producer must be a tosa::PadOp.");
+
+ // Validate that the tensor operation has well-formed padding
+ const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+ if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+ return rewriter.notifyMatchFailure(
+ tensorOp, "Tensor operation padding shall have 4 elements.");
+
+ // Validate tosa::PadOp padding
+ DenseIntElementsAttr padOpPadding;
+ if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+ return rewriter.notifyMatchFailure(
+ tensorOp,
+ "The `padding` input specified on the tosa::PadOp must be constant.");
+ }
+ // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+ // C_after
+ if (padOpPadding.size() != 8)
+ return rewriter.notifyMatchFailure(tensorOp,
+ "Pad padding should have 8 elements.");
+ int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+ int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+ int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+ int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+ int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+ int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+ int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+ int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+ if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+ return rewriter.notifyMatchFailure(
+ tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+ // Fold padding from Pad into the tensor operation
+ // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+ SmallVector<int64_t> foldedPad(tensorOpPad.size());
+ foldedPad[0] = padHBefore + tensorOpPad[0];
+ foldedPad[1] = padHAfter + tensorOpPad[1];
+ foldedPad[2] = padWBefore + tensorOpPad[2];
+ foldedPad[3] = padWAfter + tensorOpPad[3];
+
+ // Check kernel related restrictions
+ if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+ return rewriter.notifyMatchFailure(
+ tensorOp, "Padding size not aligned with kernel restrictions.");
+ }
+
+ // Check padding constant restrictions
+ if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+ return rewriter.notifyMatchFailure(
+ tensorOp,
+ "Padding constant is not aligned with operator zero-point.");
+ }
+
+ // Check that the folded padding does not exceed the 8K level limit (8192) for now
+ if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+ return rewriter.notifyMatchFailure(
+ tensorOp, "Padding size more than the 8K level limit.");
+ }
+
+ // Create operator
+ AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+ foldedPad);
+
+ return success();
+ }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+ PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+ context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<
+ FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+ context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+ ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+ context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+ PatternRewriter &rewriter) const override {
+ Value input = op.getInput();
+ Value output = op.getOutput();
+ ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+ ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+ if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+ return failure();
+ }
+
+ // If the output and input shapes are 1x1, then this is a no op.
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (outputShape[1] != 1 || outputShape[2] != 1) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ if (inputShape[1] != 1 || inputShape[2] != 1) {
+ return failure();
+ }
+
+ rewriter.replaceOp(op, input);
+ return success();
+ }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<MaxPool2dIsNoOp,
+ FoldPadToTensorOp<tosa::MaxPool2dOp,
+ PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+ context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
@@ -175,41 +442,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getInput();
- Value output = op.getOutput();
- ShapedType inputType = llvm::cast<ShapedType>(input.getType());
- ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
- if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
- return failure();
- }
-
- // If the output and input shapes are 1x1, then this is a no op.
- ArrayRef<int64_t> outputShape = outputType.getShape();
- if (outputShape[1] != 1 || outputShape[2] != 1) {
- return failure();
- }
-
- ArrayRef<int64_t> inputShape = inputType.getShape();
- if (inputShape[1] != 1 || inputShape[2] != 1) {
- return failure();
- }
-
- rewriter.replaceOp(op, input);
- return success();
- }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<MaxPool2dIsNoOp>(context);
-}
-
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 077a6cee0a1bb..3a0985f6e1868 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,158 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// -----
+// CHECK-LABEL: @pad_wh_avg_pool2d_fold
+func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+ // CHECK-NOT: tosa.pad
+ // CHECK: tosa.avg_pool2d
+ // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+ %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x5x3xf32>
+ return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold_pad_const
+func.func @pad_wh_avg_pool2d_nofold_pad_const(%input: tensor<1x10x8x3xi8>) -> tensor<1x6x5x3xi8> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.avg_pool2d
+ // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<15> : tensor<1xi8>}> : ()-> tensor<1xi8>
+ %input_zp = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> : ()-> tensor<1xi8>
+ %output_zp = "tosa.const"() <{values = dense<20> : tensor<1xi8>}> : ()-> tensor<1xi8>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xi8>, !tosa.shape<8>, tensor<1xi8>) -> tensor<1x11x9x3xi8>
+ %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x6x5x3xi8>
+ return %pool : tensor<1x6x5x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold_pad_larger_than_kernel
+func.func @pad_wh_avg_pool2d_nofold_pad_larger_than_kernel(%input: tensor<1x10x8x3xf32>) -> tensor<1x7x5x3xf32> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.avg_pool2d
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 3, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x13x9x3xf32>
+ %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x13x9x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x5x3xf32>
+ return %pool : tensor<1x7x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_conv2d_fold
+func.func @pad_wh_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<1x10x8x1xf32> {
+ // CHECK-NOT: tosa.pad
+ // CHECK: tosa.conv2d
+ // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+ %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>
+ return %conv : tensor<1x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_bwh_conv2d_nofold
+func.func @pad_bwh_conv2d_nofold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<3x10x8x1xf32> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.conv2d
+ // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ %pad_shape = tosa.const_shape { values = dense<[1, 1, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<3x10x8x3xf32>
+ %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<3x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3x10x8x1xf32>
+ return %conv : tensor<3x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_conv2d_nofold_pad_const
+func.func @pad_wh_conv2d_nofold_pad_const(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<1x10x8x1xf32> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.conv2d
+ // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+ %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>
+ return %conv : tensor<1x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_depthwise_conv2d_fold
+func.func @pad_wh_depthwise_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<3x3x3x1xf32>, %bias: tensor<3xf32>) -> tensor<1x10x8x3xf32> {
+ // CHECK-NOT: tosa.pad
+ // CHECK: tosa.depthwise_conv2d
+ // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+ %conv = tosa.depthwise_conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<3x3x3x1xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+ return %conv : tensor<1x10x8x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_fold
+func.func @pad_wh_max_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+ // CHECK-NOT: tosa.pad
+ // CHECK: tosa.max_pool2d
+ // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<-3.4028235e+38> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+ %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>) -> tensor<1x6x5x3xf32>
+ return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_nofold_pad_const
+func.func @pad_wh_max_pool2d_nofold_pad_const(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.max_pool2d
+ // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+ %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>) -> tensor<1x6x5x3xf32>
+ return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_no_fold_8k_limit
+func.func @pad_wh_max_pool2d_no_fold_8k_limit(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x4101x3xf32> {
+ // CHECK: tosa.pad
+ // CHECK: tosa.max_pool2d
+ %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 8193, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
+ %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x8201x3xf32>
+ %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x8201x3xf32>) -> tensor<1x6x4101x3xf32>
+ return %pool : tensor<1x6x4101x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @add_bcast_zero_int
func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
// CHECK-NOT: tosa.add