[Mlir-commits] [mlir] [mlir][tosa] Always generated pad_const and remove input_zp attr for PadOp (PR #129336)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 28 15:17:06 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Jerry-Ge (Jerry-Ge)
<details>
<summary>Changes</summary>
Always generate pad_const and remove the input_zp attribute from PadOp.
---
The patch is 36.88 KiB and has been truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/129336.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (-8)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+4)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-5)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+2-19)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (-47)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-22)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+14-20)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+15-10)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-1)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+14-8)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10-7)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (-8)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+25-24)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+6-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index ce17ad9362227..15def695f6a54 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -197,14 +197,6 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
input, paddings);
}]>;
-def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
- (ins "Type":$outputType, "Value":$input, "Value":$paddings,
- "Value":$pad_value),
- [{
- buildExplicitValuePadOpWithQuantInfo($_builder, $_state, outputType,
- input, paddings, pad_value);
- }]>;
-
// Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 344a54f0bb1c9..f0797f97fd842 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -168,6 +168,10 @@ namespace tosa {
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Type srcElemType, int64_t zp = 0);
+// Create a pad-const const tensor with value of `val` of required data-type
+std::optional<Value> createPadConstTensor(OpBuilder &builder, Location loc,
+ Value src, int32_t val = 0);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index abdd8347cb2b5..aedea883396f8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1882,8 +1882,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let arguments = (ins
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
- Optional<Tosa_ScalarTensor>:$pad_const,
- OptionalAttr<I32Attr>:$input_zp
+ Tosa_ScalarTensor:$pad_const
);
let results = (outs
@@ -1895,10 +1894,8 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
- let builders = [Tosa_PadOpQuantInfoBuilder,
- Tosa_ExplicitValuePadOpQuantInfoBuilder];
+ let builders = [Tosa_PadOpQuantInfoBuilder];
- let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 7f029d56e2582..6a65904272991 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -350,29 +350,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
}
ShapedType inputTy = cast<ShapedType>(input.getType());
- Type elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
// Setup the default constantAttr.
- Value padConstant;
-
- if (padOp.getPadConst()) {
- padConstant = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padOp.getPadConst(), ValueRange({}));
- } else {
- TypedAttr constantAttr;
- if (isa<FloatType>(elementTy)) {
- constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
- constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
- int64_t value = padOp.getInputZpAttr().getInt();
- constantAttr = rewriter.getIntegerAttr(elementTy, value);
- }
- if (constantAttr)
- padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
- }
+ Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+ loc, padOp.getPadConst(), ValueRange({}));
if (!padConstant) {
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 363b5958bc0fd..2c0376134b599 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -175,53 +175,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
-struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::PadOp op,
- PatternRewriter &rewriter) const override {
- if (op.getPadConst())
- return failure();
-
- auto input = op.getInput1();
- auto padding = op.getPadding();
-
- ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
- Type elementTy = inputTy.getElementType();
-
- Attribute constantAttr;
- if (llvm::isa<FloatType>(elementTy)) {
- constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
- constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
- int64_t value = op.getInputZpAttr().getInt();
- constantAttr = rewriter.getIntegerAttr(elementTy, value);
- }
-
- if (!constantAttr) {
- return rewriter.notifyMatchFailure(
- op,
- "tosa.pad to linalg lowering encountered an unknown element type");
- }
-
- auto denseAttr = DenseElementsAttr::get(
- RankedTensorType::get({1}, elementTy), constantAttr);
- auto constantVal = rewriter.create<tosa::ConstOp>(
- op.getLoc(), denseAttr.getType(), denseAttr);
-
- rewriter.replaceOpWithNewOp<tosa::PadOp>(
- op, op.getType(), ValueRange{input, padding, constantVal},
- op->getAttrs());
- return success();
- }
-};
-
-void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<MaterializePadValue>(context);
-}
-
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 54f9fa917f2e0..a76a687c3f1eb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -214,6 +214,23 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
}
}
+// Create a pad-const const tensor with value of `val` of required data-type
+std::optional<Value> mlir::tosa::createPadConstTensor(OpBuilder &builder,
+ Location loc, Value src,
+ int32_t val) {
+ auto const srcType = getElementTypeOrSelf(src);
+ auto const srcElemType = getElementTypeOrSelf(src);
+ auto const padConstType = mlir::RankedTensorType::get({1}, srcType);
+ auto const padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
+ auto const pad_const_attr{
+ llvm::isa<FloatType>(srcElemType)
+ ? DenseElementsAttr::get(padConstEType,
+ builder.getFloatAttr(srcElemType, val))
+ : DenseElementsAttr::get(padConstEType,
+ builder.getIntegerAttr(srcElemType, val))};
+ return builder.create<tosa::ConstOp>(loc, padConstType, pad_const_attr);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -679,30 +696,14 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
Value paddings) {
- result.addOperands({input, paddings});
- auto quantAttr = buildPadOpQuantizationAttr(builder, input);
+ const Location loc{result.location};
+ int32_t zp{0};
+ auto const quantAttr = buildPadOpQuantizationAttr(builder, input);
if (quantAttr) {
- result.addAttribute("input_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getInputZp())));
- }
- result.types.push_back(outputType);
-}
-
-/// This builder is called on TOSA pad operator when an explicit pad_const
-/// value is passed in. It also optionally constructs quantization_attr.
-static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
- OperationState &result,
- Type outputType, Value input,
- Value paddings,
- Value padConst) {
- result.addOperands({input, paddings, padConst});
- auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr) {
- result.addAttribute("input_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getInputZp())));
+ zp = static_cast<int32_t>(quantAttr.getInputZp());
}
+ auto const pad_const_op{createPadConstTensor(builder, loc, input, zp)};
+ result.addOperands({input, paddings, pad_const_op.value()});
result.types.push_back(outputType);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 83bdbce5d1857..b629c3e7df510 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -148,16 +148,16 @@ class TransposeConvStridedConverter
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
- if (weightZpVal != 0) {
- weight = CreateOpAndInferShape<tosa::PadOp>(
- rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- weightPaddingVal, nullptr, rewriter.getI32IntegerAttr(weightZpVal));
-
- } else {
- weight = CreateOpAndInferShape<tosa::PadOp>(
- rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- weightPaddingVal);
- }
+ // construct pad_const values from zp values
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ Value const inputPadConst =
+ createPadConstTensor(builder, op->getLoc(), input, inputZpVal).value();
+ Value const weightPadConst =
+ createPadConstTensor(builder, op->getLoc(), input, weightZpVal).value();
+
+ weight = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ weightPaddingVal, weightPadConst);
weightTy = cast<ShapedType>(weight.getType());
weightHeight = weightTy.getDimSize(1);
@@ -169,7 +169,7 @@ class TransposeConvStridedConverter
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
- ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
builder, UnrankedTensorType::get(weightETy), weight,
getTosaConstShape(rewriter, loc, weightReshapeDims0));
@@ -206,15 +206,9 @@ class TransposeConvStridedConverter
Value inputPaddingVal =
getTosaConstShape(rewriter, op->getLoc(), inputPadding);
- if (inputZpVal != 0) {
- input = CreateOpAndInferShape<tosa::PadOp>(
- rewriter, loc, UnrankedTensorType::get(inputETy), input,
- inputPaddingVal, nullptr, rewriter.getI32IntegerAttr(inputZpVal));
- } else {
- input = CreateOpAndInferShape<tosa::PadOp>(
- rewriter, loc, UnrankedTensorType::get(inputETy), input,
- inputPaddingVal);
- }
+ input = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(inputETy), input,
+ inputPaddingVal, inputPadConst);
// We use a zero bias as we need to broadcast the bias.
auto zeroBias = rewriter.create<tosa::ConstOp>(
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 6b7f622d3303f..c7a689f5a9ae9 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -498,35 +498,38 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<4x9xf32>)
return %1 : tensor<4x9xf32>
}
// -----
func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // CHECK: [[CST:%.+]] = arith.constant 0 : i32
+ %pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // CHECK: [[CST:%.+]] = arith.constant 3 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xi32>, !tosa.shape<4>, tensor<1xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
// -----
func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // CHECK: [[CST:%.+]] = arith.constant 42 : i32
+ %pad_const = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ // CHECK: [[CST:%.+]] = arith.constant 0 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0, %pad_const) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>, tensor<1xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
@@ -551,30 +554,32 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<?x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
// -----
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
%0 = tosa.const_shape {value = dense<[-1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant -1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 7324b0ea52e89..4203132e9f702 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -512,9 +512,10 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
// CHECK-LABEL: pad
func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
- %0 = tosa.pad %arg0, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
+ %0 = tosa.pad %arg0, %padding, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 175145f332f8e..f7874aaebee21 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -258,7 +258,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %arg0
%0 = tosa.const_shape { value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %1 = tosa.pad %arg0, %0, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -269,7 +270,8 @@ func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
%shape = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- %1 = tosa.pad %arg0, %shape : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %1 = tosa.pad %arg0, %shape, %pad_const : (tensor<?x?xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -280,7 +282,8 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
%shape = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
- %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, !tosa.shape<2>) -> tensor<?xf32>
+ %pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %0 = tosa.pad %arg0, %shape, %pad_const : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -291,8 +294,9 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
// CHECK-DAG: %[[ZERO:.+]...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/129336
More information about the Mlir-commits mailing list