[Mlir-commits] [mlir] 2a19625 - mlir/tosa: move tosa.pad from Linalg to Tensor conversion
Ramkumar Ramachandra
llvmlistbot at llvm.org
Mon Dec 5 22:40:06 PST 2022
Author: Ramkumar Ramachandra
Date: 2022-12-06T07:39:29+01:00
New Revision: 2a1962542423f62928a4c2e6cf42e97b190de49d
URL: https://github.com/llvm/llvm-project/commit/2a1962542423f62928a4c2e6cf42e97b190de49d
DIFF: https://github.com/llvm/llvm-project/commit/2a1962542423f62928a4c2e6cf42e97b190de49d.diff
LOG: mlir/tosa: move tosa.pad from Linalg to Tensor conversion
Since tosa.pad is lowered strictly to arith and tensor ops, move
PadConverter from TosaToLinalg to TosaToTensor, benefiting non-Linalg
Tosa targets. TensorToLinalg exists and is trivial, so nothing is lost.
Signed-off-by: Ramkumar Ramachandra <r at artagnon.com>
Differential Revision: https://reviews.llvm.org/D139091
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ade94e1ce8aed..3c74da1e939d9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1932,81 +1932,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
}
};
-class PadConverter : public OpRewritePattern<tosa::PadOp> { // Lowers tosa.pad to tensor.pad.
-public:
- using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
- // Replaces padOp with a tensor.pad; fails to match if no pad value can be chosen.
- LogicalResult matchAndRewrite(tosa::PadOp padOp,
- PatternRewriter &rewriter) const final {
- auto loc = padOp.getLoc();
- auto input = padOp.getInput1();
- auto padding = padOp.getPadding();
-
- ShapedType inputTy = input.getType().cast<ShapedType>();
- Type elementTy = inputTy.getElementType();
- int64_t rank = inputTy.getRank();
-
- // Pad value: the pad_const operand if present; otherwise 0 for float/integer
- // element types, or quantization_info's input_zp for quantized integers.
- Value padConstant;
- // An index-free extract reads pad_const as a scalar (presumably a 0-d tensor).
- if (padOp.getPadConst()) {
- padConstant = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padOp.getPadConst(), ValueRange({}));
- } else {
- Attribute constantAttr;
- if (elementTy.isa<FloatType>()) {
- constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
- constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
- int64_t value = padOp.getQuantizationInfo()->getInputZp();
- constantAttr = rewriter.getIntegerAttr(elementTy, value);
- }
- if (constantAttr)
- padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
- }
- // padConstant stays null for other element types when pad_const is absent.
- if (!padConstant) {
- return rewriter.notifyMatchFailure(
- padOp, "tosa.pad was unable to determine the pad constant value.");
- }
- // Column indices into padding: 0 selects the low amount, 1 the high amount.
- Value lowIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value highIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
- // NOTE: not shared with the loop's index constants; output may repeat them.
- SmallVector<OpFoldResult, 3> lowValues;
- SmallVector<OpFoldResult, 3> highValues;
-
- lowValues.reserve(rank);
- highValues.reserve(rank);
- // For each dimension i: low = padding[i][0], high = padding[i][1].
- for (int i = 0; i < rank; i++) {
- Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
- Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, lowIndex}));
- Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, highIndex}));
- // Cast to index; with a constant padding operand these can fold to constants.
- lowVal = rewriter.createOrFold<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), lowVal);
- highVal = rewriter.createOrFold<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), highVal);
-
- lowValues.push_back(lowVal);
- highValues.push_back(highVal);
- }
- // Keep padOp's result type; padded elements take the value padConstant.
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, padOp.getType(), input, lowValues, highValues, padConstant);
-
- rewriter.replaceOp(padOp, newPadOp.getResult());
- return success();
- }
-};
-
// Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic
// op, producing two output buffers.
//
@@ -2375,7 +2300,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ArgMaxConverter,
ConcatConverter,
GatherConverter,
- PadConverter,
ReshapeConverterCollapse,
ReshapeConverterExpand,
ReshapeConverterCollapseExpand,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 5290923c25b8a..4b1e351e9746e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -56,6 +56,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::SliceOp>();
+ target.addLegalOp<tosa::PadOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index cb2eea2960e3d..047cb31fa477b 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -22,7 +22,7 @@ using namespace tosa;
namespace {
-class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
public:
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
@@ -59,9 +59,84 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
}
};
+class PadConverter : public OpRewritePattern<tosa::PadOp> { // Lowers tosa.pad to tensor.pad.
+public:
+ using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+ // Replaces padOp with a tensor.pad; fails to match if no pad value can be chosen.
+ LogicalResult matchAndRewrite(tosa::PadOp padOp,
+ PatternRewriter &rewriter) const final {
+ auto loc = padOp.getLoc();
+ auto input = padOp.getInput1();
+ auto padding = padOp.getPadding();
+
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ Type elementTy = inputTy.getElementType();
+ int64_t rank = inputTy.getRank();
+
+ // Pad value: the pad_const operand if present; otherwise 0 for float/integer
+ // element types, or quantization_info's input_zp for quantized integers.
+ Value padConstant;
+ // An index-free extract reads pad_const as a scalar (presumably a 0-d tensor).
+ if (padOp.getPadConst()) {
+ padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+ loc, padOp.getPadConst(), ValueRange({}));
+ } else {
+ Attribute constantAttr;
+ if (elementTy.isa<FloatType>()) {
+ constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+ } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
+ constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+ } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
+ int64_t value = padOp.getQuantizationInfo()->getInputZp();
+ constantAttr = rewriter.getIntegerAttr(elementTy, value);
+ }
+ if (constantAttr)
+ padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
+ }
+ // padConstant stays null for other element types when pad_const is absent.
+ if (!padConstant) {
+ return rewriter.notifyMatchFailure(
+ padOp, "tosa.pad was unable to determine the pad constant value.");
+ }
+ // Column indices into padding: 0 selects the low amount, 1 the high amount.
+ Value lowIndex =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+ Value highIndex =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+ // NOTE: not shared with the loop's index constants; output may repeat them.
+ SmallVector<OpFoldResult, 3> lowValues;
+ SmallVector<OpFoldResult, 3> highValues;
+
+ lowValues.reserve(rank);
+ highValues.reserve(rank);
+ // For each dimension i: low = padding[i][0], high = padding[i][1].
+ for (int i = 0; i < rank; i++) {
+ Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+ Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+ loc, padding, ValueRange({inputIndex, lowIndex}));
+ Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+ loc, padding, ValueRange({inputIndex, highIndex}));
+ // Cast to index; with a constant padding operand these can fold to constants.
+ lowVal = rewriter.createOrFold<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), lowVal);
+ highVal = rewriter.createOrFold<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), highVal);
+
+ lowValues.push_back(lowVal);
+ highValues.push_back(highVal);
+ }
+ // Keep padOp's result type; padded elements take the value padConstant.
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, padOp.getType(), input, lowValues, highValues, padConstant);
+
+ rewriter.replaceOp(padOp, newPadOp.getResult());
+ return success();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) { // Adds the TOSA-to-Tensor rewrite patterns to *patterns.
- patterns->add<SliceOpConverter>(patterns->getContext());
+ patterns->add<SliceConverter, PadConverter>(patterns->getContext()); // tosa.slice and tosa.pad lowerings.
}
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index bf8d709cc859c..af6a08e7bcf14 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -36,6 +36,7 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::SliceOp>();
+ target.addIllegalOp<tosa::PadOp>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 349ad7ed5c864..8406c050075c1 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1301,93 +1301,6 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
// -----
-// CHECK-LABEL: @pad_float
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { // Float input, no pad_const operand: pads with 0.0.
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index". Presumably why the dim-0 low amount (1) is matched with a wildcard below.
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
- // CHECK: tensor.yield [[CST]]
- // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
- return %1 : tensor<4x9xf32>
-}
-
-func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { // Integer input, no pad_const, no quantization_info: pads with 0.
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: [[CST:%.+]] = arith.constant 0 : i32
- // CHECK: tensor.pad
- // CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
- return %1 : tensor<4x9xi32>
-}
-
-func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { // Integer input with quantization_info: pads with input_zp (42).
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: [[CST:%.+]] = arith.constant 42 : i32
- // CHECK: tensor.pad
- // CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
- return %1 : tensor<4x9xi32>
-}
-
-// -----
-
-func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { // Explicit pad_const (42.0). NOTE(review): %[[ARG0]] below reuses the capture from @pad_float.
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
- // CHECK: tensor.yield [[CST]]
- // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = arith.constant dense<42.0> : tensor<f32>
- %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
- return %2 : tensor<4x9xf32>
-}
-
-// -----
-
-func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) { // Dynamic dim 0 input. NOTE(review): %[[ARG0]] below reuses the capture from @pad_float.
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
- // CHECK: tensor.yield [[CST]]
- // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
- return %1 : tensor<?x9xf32>
-}
-
-func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) { // Low padding -1 on dim 0, dynamic result dim 0. NOTE(review): %[[ARG0]] reuses the @pad_float capture.
- %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
- // CHECK: tensor.yield [[CST]]
- // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
- return %1 : tensor<?x9xf32>
-}
-
-// -----
-
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 08105c72eb5a0..b50af43de021a 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -19,3 +19,90 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
%0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>) -> (tensor<?xf32>)
return %0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: @pad_float
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { // Float input, no pad_const operand: pads with 0.0.
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index". Presumably why the dim-0 low amount (1) is matched with a wildcard below.
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+ return %1 : tensor<4x9xf32>
+}
+
+// Integer input with no pad_const operand and no quantization_info: pads with 0.
+// CHECK-LABEL: @pad_int
+func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK: [[CST:%.+]] = arith.constant 0 : i32
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield [[CST]]
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ return %1 : tensor<4x9xi32>
+}
+
+// Integer input with quantization_info: pads with the input zero point (42).
+// CHECK-LABEL: @pad_quant
+func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // CHECK: [[CST:%.+]] = arith.constant 42 : i32
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield [[CST]]
+ %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ return %1 : tensor<4x9xi32>
+}
+
+// -----
+
+// Float input with an explicit pad_const operand (42.0) that supplies the pad value.
+// CHECK-LABEL: @pad_float_explicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index".
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+ %1 = arith.constant dense<42.0> : tensor<f32>
+ %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+ return %2 : tensor<4x9xf32>
+}
+
+// -----
+
+// Dynamic leading input dimension (tensor<?x2xf32>) with constant padding amounts.
+// CHECK-LABEL: @pad_dyn_input
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index".
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ return %1 : tensor<?x9xf32>
+}
+
+// Negative low padding (-1) on dim 0 together with a dynamic result dim 0.
+// CHECK-LABEL: @pad_dyn_padding
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+ %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index".
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ return %1 : tensor<?x9xf32>
+}
More information about the Mlir-commits
mailing list