[Mlir-commits] [mlir] 024a1fa - [tosa][mlir] Add dynamic shape support for remaining ops
Rob Suderman
llvmlistbot at llvm.org
Thu Jan 27 11:36:19 PST 2022
Author: natashaknk
Date: 2022-01-27T11:25:38-08:00
New Revision: 024a1fab5c35f630c0b7de721eba497692d081fe
URL: https://github.com/llvm/llvm-project/commit/024a1fab5c35f630c0b7de721eba497692d081fe
DIFF: https://github.com/llvm/llvm-project/commit/024a1fab5c35f630c0b7de721eba497692d081fe.diff
LOG: [tosa][mlir] Add dynamic shape support for remaining ops
Added support for concat, tile, pad, argmax and table ops
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D118397
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6833a0c2d72cb..ba405f5ab1e98 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1681,11 +1681,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
LogicalResult
matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
auto resultType = op.getType().dyn_cast<RankedTensorType>();
- if (!resultType || !resultType.hasStaticShape()) {
- return rewriter.notifyMatchFailure(op,
- "expected static shaped tensor type");
- }
Location loc = op.getLoc();
int axis = op.axis();
@@ -1697,9 +1694,14 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ SmallVector<Value> dynDims;
for (int i = 0; i < rank; ++i) {
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
loc, adaptor.getOperands()[0], i));
+ if (inputType.isDynamicDim(i)) {
+ dynDims.push_back(
+ rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+ }
}
Value resultDimSize = sizes[axis];
@@ -1711,7 +1713,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
sizes[axis] = resultDimSize;
Value init = rewriter.create<linalg::InitTensorOp>(
- loc, resultType.getShape(), resultType.getElementType());
+ loc, dynDims, resultType.getShape(), resultType.getElementType());
Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultType.getElementType()));
@@ -1815,9 +1817,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
- if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
- return failure();
-
SmallVector<int64_t> multiples;
getValuesFromIntArrayAttribute(op.multiples(), multiples);
@@ -1828,8 +1827,15 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
genericShape.push_back(inputShape[i]);
}
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < inputTy.getRank(); i++) {
+ if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
+ }
+
auto initTensor = rewriter.create<linalg::InitTensorOp>(
- op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+ op.getLoc(), dynDims, genericShape, elementTy);
// We needs to map the input shape to the non-broadcasted dimensions.
SmallVector<AffineExpr, 4> dimExprs;
@@ -1870,16 +1876,9 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
auto padding = padOp.padding();
ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType paddingTy = padding.getType().cast<ShapedType>();
Type elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
- if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
- return rewriter.notifyMatchFailure(
- padOp,
- "Pad converter requires static shaped input / padding values.");
- }
-
// Setup the default constantAttr.
Value padConstant;
@@ -1970,21 +1969,23 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
int axis = argmaxOp.axis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
- if (!inputTy.hasStaticShape())
- return rewriter.notifyMatchFailure(
- argmaxOp,
- "tosa.arg_max to linalg.* requires statically shaped input");
-
if (!outElementTy.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"tosa.arg_max to linalg.* requires integer-like result type");
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < inputTy.getRank(); i++) {
+ if (inputTy.isDynamicDim(i) && i != axis) {
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
+ }
+
// First fill the output buffer for the index.
auto initTensorIdx =
rewriter
- .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
- resultTy.getShape(), outElementTy)
+ .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+ outElementTy)
.result();
auto fillValueIdx = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(outElementTy, 0));
@@ -1993,11 +1994,10 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
.result();
// Second fill the output buffer for the running max.
- auto initTensorMax =
- rewriter
- .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
- resultTy.getShape(), inElementTy)
- .result();
+ auto initTensorMax = rewriter
+ .create<linalg::InitTensorOp>(
+ loc, dynDims, resultTy.getShape(), inElementTy)
+ .result();
auto fillValueMaxAttr =
createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2138,18 +2138,22 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto tableTy = table.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();
- if (!inputTy.hasStaticShape())
- return rewriter.notifyMatchFailure(
- op, "require input type to have static shape");
-
auto inputElementTy = inputTy.getElementType();
auto tableElementTy = tableTy.getElementType();
auto resultElementTy = resultTy.getElementType();
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < resultTy.getRank(); ++i) {
+ if (inputTy.isDynamicDim(i)) {
+ dynDims.push_back(
+ rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+ }
+ }
+
auto initTensor =
rewriter
- .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
- resultTy.getShape(), resultElementTy)
+ .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+ resultElementTy)
.result();
SmallVector<AffineMap, 2> affineMaps = {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 55b8bce54b1a2..cb7f42ba4242a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -910,6 +910,50 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// -----
+// CHECK-LABEL: @concat_non_axis_dyn
+func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+ // CHECK: %[[AXIS:.+]] = arith.constant 0
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+ // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX1]]
+ // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]]
+ // CHECK: %[[CST:.+]] = arith.constant 0.0
+ // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
+ // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+ // CHECK: %[[AXIS:.+]] = arith.constant 0
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+ // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX0]]
+ // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX0_2]]
+ // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3]
+ // CHECK: %[[CST:.+]] = arith.constant 0.0
+ // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+ // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]]
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
+ // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
+ // CHECK: %[[DYN2:.+]] = tensor.dim %arg1, %[[AXIS]]
+ // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
+ return
+}
+
+// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @rescale_i8
@@ -1150,6 +1194,44 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
// -----
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_input
+func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
+ // CHECK: %[[CST0:.+]] = arith.constant 0
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] : tensor<?x3xi8>
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3]
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
+ // CHECK: linalg.yield %arg1 : i8
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+ // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+ %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<?x3xi8>) -> (tensor<?x3xi8>)
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_multiples
+func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
+ // CHECK: %[[CST1:.+]] = arith.constant 1
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] : tensor<2x3xi8>
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3]
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
+ // CHECK: linalg.yield %arg1 : i8
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+ // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+ %0 = "tosa.tile"(%arg0) {multiples = [2, -1]} : (tensor<2x3xi8>) -> (tensor<2x?xi8>)
+
+ return
+}
+
+// -----
+
func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
%0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
// TODO: Output contains multiple "arith.constant 1 : index".
@@ -1205,6 +1287,40 @@ func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
// -----
+func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
+ %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index".
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: ^bb0(%arg1: index, %arg2: index):
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ return %1 : tensor<?x9xf32>
+}
+
+func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+ %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+ // TODO: Output contains multiple "arith.constant 1 : index".
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: ^bb0(%arg1: index, %arg2: index):
+ // CHECK: tensor.yield [[CST]]
+ // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ return %1 : tensor<?x9xf32>
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -1256,6 +1372,54 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
+ // CHECK: %[[CST1:.+]] = arith.constant 1
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]]
+ // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+ // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+ // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+ // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+ // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+ // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+ // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
+ // CHECK: %[[IDX:.+]] = linalg.index 0
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+ // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+ // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+ // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+ %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x?xi32>) -> (tensor<?xi32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+
+func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
+ // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3]
+ // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+ // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+ // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3]
+ // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+ // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+ // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+ // CHECK: %[[IDX:.+]] = linalg.index 1
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+ // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+ // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+ // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+ %0 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x?xi32>) -> (tensor<3xi32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @gather_float
func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@@ -1349,6 +1513,40 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
// -----
+// CHECK-LABEL: @table8_dyn
+func @table8_dyn(%arg0: tensor<?xi8>, %arg1: tensor<512xi8>) -> () {
+ // CHECK: %[[CST0:.+]] = arith.constant 0
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xi8>) outs(%[[INIT]] : tensor<?xi8>)
+ // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+ // CHECK: %[[OFFSET:.+]] = arith.constant 128
+ // CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.table"(%arg0, %arg1) : (tensor<?xi8>, tensor<512xi8>) -> (tensor<?xi8>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @table8_dyn_table
+func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
+ // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+ // CHECK: %[[OFFSET:.+]] = arith.constant 128
+ // CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<?xi8>) -> (tensor<6xi8>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @resize_nearest
func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
More information about the Mlir-commits mailing list