[Mlir-commits] [mlir] 024a1fa - [tosa][mlir] Add dynamic shape support for remaining ops

Rob Suderman llvmlistbot at llvm.org
Thu Jan 27 11:36:19 PST 2022


Author: natashaknk
Date: 2022-01-27T11:25:38-08:00
New Revision: 024a1fab5c35f630c0b7de721eba497692d081fe

URL: https://github.com/llvm/llvm-project/commit/024a1fab5c35f630c0b7de721eba497692d081fe
DIFF: https://github.com/llvm/llvm-project/commit/024a1fab5c35f630c0b7de721eba497692d081fe.diff

LOG: [tosa][mlir] Add dynamic shape support for remaining ops

Added dynamic shape support for the concat, tile, pad, argmax and table ops.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D118397

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6833a0c2d72cb..ba405f5ab1e98 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1681,11 +1681,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   LogicalResult
   matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
     auto resultType = op.getType().dyn_cast<RankedTensorType>();
-    if (!resultType || !resultType.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "expected static shaped tensor type");
-    }
 
     Location loc = op.getLoc();
     int axis = op.axis();
@@ -1697,9 +1694,14 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
     offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
+    SmallVector<Value> dynDims;
     for (int i = 0; i < rank; ++i) {
       sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
           loc, adaptor.getOperands()[0], i));
+      if (inputType.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
     }
 
     Value resultDimSize = sizes[axis];
@@ -1711,7 +1713,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     sizes[axis] = resultDimSize;
 
     Value init = rewriter.create<linalg::InitTensorOp>(
-        loc, resultType.getShape(), resultType.getElementType());
+        loc, dynDims, resultType.getShape(), resultType.getElementType());
 
     Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
@@ -1815,9 +1817,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
-
     SmallVector<int64_t> multiples;
     getValuesFromIntArrayAttribute(op.multiples(), multiples);
 
@@ -1828,8 +1827,15 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
       genericShape.push_back(inputShape[i]);
     }
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+        op.getLoc(), dynDims, genericShape, elementTy);
 
     // We needs to map the input shape to the non-broadcasted dimensions.
     SmallVector<AffineExpr, 4> dimExprs;
@@ -1870,16 +1876,9 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
     auto padding = padOp.padding();
 
     ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType paddingTy = padding.getType().cast<ShapedType>();
     Type elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(
-          padOp,
-          "Pad converter requires static shaped input / padding values.");
-    }
-
     // Setup the default constantAttr.
 
     Value padConstant;
@@ -1970,21 +1969,23 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
     int axis = argmaxOp.axis();
     auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          argmaxOp,
-          "tosa.arg_max to linalg.* requires statically shaped input");
-
     if (!outElementTy.isa<IntegerType>())
       return rewriter.notifyMatchFailure(
           argmaxOp,
           "tosa.arg_max to linalg.* requires integer-like result type");
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i) && i != axis) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
     // First fill the output buffer for the index.
     auto initTensorIdx =
         rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
-                                          resultTy.getShape(), outElementTy)
+            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+                                          outElementTy)
             .result();
     auto fillValueIdx = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getIntegerAttr(outElementTy, 0));
@@ -1993,11 +1994,10 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
             .result();
 
     // Second fill the output buffer for the running max.
-    auto initTensorMax =
-        rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
-                                          resultTy.getShape(), inElementTy)
-            .result();
+    auto initTensorMax = rewriter
+                             .create<linalg::InitTensorOp>(
+                                 loc, dynDims, resultTy.getShape(), inElementTy)
+                             .result();
     auto fillValueMaxAttr =
         createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
 
@@ -2138,18 +2138,22 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
     auto tableTy = table.getType().cast<ShapedType>();
     auto resultTy = op.getType().cast<ShapedType>();
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "require input type to have static shape");
-
     auto inputElementTy = inputTy.getElementType();
     auto tableElementTy = tableTy.getElementType();
     auto resultElementTy = resultTy.getElementType();
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < resultTy.getRank(); ++i) {
+      if (inputTy.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
+    }
+
     auto initTensor =
         rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
-                                          resultTy.getShape(), resultElementTy)
+            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+                                          resultElementTy)
             .result();
 
     SmallVector<AffineMap, 2> affineMaps = {

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 55b8bce54b1a2..cb7f42ba4242a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -910,6 +910,50 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @concat_non_axis_dyn
+func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX1]]
+  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]]
+  // CHECK: %[[CST:.+]] = arith.constant 0.0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX0]]
+  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX0_2]]
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3]
+  // CHECK: %[[CST:.+]] = arith.constant 0.0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]]
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
+  // CHECK: %[[SUM:.+]]  = arith.addi %[[OFFSET]], %[[DYN1]]
+  // CHECK: %[[DYN2:.+]] = tensor.dim %arg1, %[[AXIS]]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
+  return
+}
+
+// -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @rescale_i8
@@ -1150,6 +1194,44 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_input
+func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
+  // CHECK: %[[CST0:.+]] = arith.constant 0
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] : tensor<?x3xi8>
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
+  // CHECK:   linalg.yield %arg1 : i8
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<?x3xi8>)  -> (tensor<?x3xi8>)
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_multiples
+func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
+  // CHECK: %[[CST1:.+]] = arith.constant 1
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] : tensor<2x3xi8>
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
+  // CHECK:   linalg.yield %arg1 : i8
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, -1]} : (tensor<2x3xi8>)  -> (tensor<2x?xi8>)
+
+  return
+}
+
+// -----
+
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
@@ -1205,6 +1287,40 @@ func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
 
 // -----
 
+func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  return %1 : tensor<?x9xf32>
+}
+
+func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  return %1 : tensor<?x9xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -1256,6 +1372,54 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
+  // CHECK: %[[CST1:.+]] = arith.constant 1
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]]
+  // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
+  // CHECK:   %[[IDX:.+]] = linalg.index 0
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+  // CHECK:   %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+  // CHECK:   %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+  // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x?xi32>)  -> (tensor<?xi32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+
+func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
+  // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3]
+  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3]
+  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+  // CHECK:   %[[IDX:.+]] = linalg.index 1
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+  // CHECK:   %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+  // CHECK:   %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+  // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x?xi32>)  -> (tensor<3xi32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @gather_float
 func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@@ -1349,6 +1513,40 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
 
 // -----
 
+// CHECK-LABEL: @table8_dyn
+func @table8_dyn(%arg0: tensor<?xi8>, %arg1: tensor<512xi8>) -> () {
+  // CHECK: %[[CST0:.+]] = arith.constant 0
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xi8>) outs(%[[INIT]] : tensor<?xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+  // CHECK:   %[[OFFSET:.+]] = arith.constant 128
+  // CHECK:   %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1)  : (tensor<?xi8>, tensor<512xi8>)  -> (tensor<?xi8>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @table8_dyn_table
+func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+  // CHECK:   %[[OFFSET:.+]] = arith.constant 128
+  // CHECK:   %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1)  : (tensor<6xi8>, tensor<?xi8>)  -> (tensor<6xi8>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @resize_nearest
 func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]


        


More information about the Mlir-commits mailing list