[Mlir-commits] [mlir] 5a4e776 - [mlir][tosa] Added more shape inference for tosa ops

Rob Suderman llvmlistbot at llvm.org
Mon Jul 12 10:11:13 PDT 2021


Author: Rob Suderman
Date: 2021-07-12T10:04:49-07:00
New Revision: 5a4e7760101581c394aa4235c55885ec960c8b3b

URL: https://github.com/llvm/llvm-project/commit/5a4e7760101581c394aa4235c55885ec960c8b3b
DIFF: https://github.com/llvm/llvm-project/commit/5a4e7760101581c394aa4235c55885ec960c8b3b.diff

LOG: [mlir][tosa] Added more shape inference for tosa ops

Added shape inference for:
- scatter
- gather
- transpose
- slice
- pad
- concat
- reduction operations

Also updated reshape for more aggressive shape inference.

Differential Revision: https://reviews.llvm.org/D105383

Added: 
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Removed: 
    mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 06867ef199e11..76cd66aac064e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -32,7 +32,10 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 //===----------------------------------------------------------------------===//
 // Operator: argmax
 //===----------------------------------------------------------------------===//
-def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> {
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Perform argmax on the input.";
 
   let description = [{
@@ -173,7 +176,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
-def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
+def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Fully Connected operator";
 
   let description = [{
@@ -199,7 +205,10 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
-def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
+def Tosa_MatMulOp : Tosa_Op<"matmul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Matrix multiplication with bias";
 
   let description = [{
@@ -589,8 +598,9 @@ def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
 def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>, ResultsBroadcastableShape,
-    NoSideEffect]> {
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -783,7 +793,10 @@ def Tosa_SubOp : Tosa_Op<"sub", [
 //===----------------------------------------------------------------------===//
 // Operator: table
 //===----------------------------------------------------------------------===//
-def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> {
+def Tosa_TableOp : Tosa_Op<"table", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Table lookup op";
 
   let description = [{
@@ -1178,7 +1191,10 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
 //===----------------------------------------------------------------------===//
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
+def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1198,7 +1214,10 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
+def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1218,7 +1237,10 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
+def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1238,7 +1260,10 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
+def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1258,7 +1283,10 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
+def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1278,7 +1306,10 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
+def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1303,7 +1334,10 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: concat
 //===----------------------------------------------------------------------===//
-def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
+def Tosa_ConcatOp : Tosa_Op<"concat", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
@@ -1324,7 +1358,10 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: pad
 //===----------------------------------------------------------------------===//
-def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
+def Tosa_PadOp : Tosa_Op<"pad", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Pads a tensor with zeros.";
 
   let description = [{
@@ -1396,7 +1433,9 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 //===----------------------------------------------------------------------===//
 // Operator: slice
 //===----------------------------------------------------------------------===//
-def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
+def Tosa_SliceOp: Tosa_Op<"slice", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, NoSideEffect]> {
   let summary = "Slice operator";
 
   let description = [{
@@ -1419,7 +1458,10 @@ def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: tile
 //===----------------------------------------------------------------------===//
-def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
+def Tosa_TileOp: Tosa_Op<"tile", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, 
+      NoSideEffect]> {
   let summary = "Tile operator";
 
   let description = [{
@@ -1438,7 +1480,10 @@ def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
+def Tosa_TransposeOp : Tosa_Op<"transpose", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, 
+      NoSideEffect]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1463,7 +1508,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: gather
 //===----------------------------------------------------------------------===//
-def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
+def Tosa_GatherOp : Tosa_Op<"gather", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, 
+      NoSideEffect]> {
   let summary = "Gather operation,";
 
   let description = [{
@@ -1484,7 +1532,10 @@ def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
-def Tosa_ScatterOp : Tosa_Op<"scatter", [NoSideEffect]> {
+def Tosa_ScatterOp : Tosa_Op<"scatter", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+      NoSideEffect]> {
   let summary = "Scatter operation,";
 
   let description = [{

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fd744372fceab..9126f1776ca21 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -301,6 +302,260 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
   }
 }
 
+LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
+  int32_t axisVal = axis.getValue().getSExtValue();
+
+  if (!inputTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  SmallVector<int64_t> outShape;
+  outShape.reserve(inputTy.getRank() - 1);
+  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+    if (i == axisVal)
+      continue;
+    outShape.push_back(inputTy.getDimSize(i));
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Infer all dimension sizes by reducing based on inputs.
+  int32_t axis =
+      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
+  llvm::SmallVector<int64_t> outputShape;
+  bool hasRankedInput = false;
+  for (auto operand : operands) {
+    ShapedType operandTy = operand.getType().cast<ShapedType>();
+    if (!operandTy.hasRank())
+      continue;
+
+    // Copy the Operand's rank.
+    if (!hasRankedInput)
+      outputShape.resize(operandTy.getRank(), -1);
+
+    // Copy shapes until the dim is non-dynamic.
+    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
+      if (i == axis || operandTy.isDynamicDim(i))
+        continue;
+      if (outputShape[i] == -1)
+        outputShape[i] = operandTy.getDimSize(i);
+      if (outputShape[i] != operandTy.getDimSize(i))
+        return failure();
+    }
+
+    hasRankedInput = true;
+  }
+
+  if (!hasRankedInput) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // Determine the dimension size along the concatenation axis.
+  int concatDimSize = 0;
+  for (auto operand : operands) {
+    ShapedType operandTy = operand.getType().cast<ShapedType>();
+
+    // We need to know the length of the concatenation axis of all inputs to
+    // determine the dimension size of the output shape.
+    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
+      concatDimSize = -1;
+      break;
+    }
+
+    concatDimSize += operandTy.getDimSize(axis);
+  }
+
+  outputShape[axis] = concatDimSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapedType weightTy = operands[1].getType().cast<ShapedType>();
+  ShapedType biasTy = operands[2].getType().cast<ShapedType>();
+
+  // All shapes are dynamic.
+  SmallVector<int64_t> outShape;
+  outShape.resize(2, -1);
+
+  if (inputTy.hasRank()) {
+    outShape[0] = inputTy.getDimSize(0);
+  }
+
+  if (weightTy.hasRank()) {
+    outShape[1] = weightTy.getDimSize(0);
+  }
+
+  if (biasTy.hasRank()) {
+    outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType lhsTy = operands[0].getType().cast<ShapedType>();
+  ShapedType rhsTy = operands[1].getType().cast<ShapedType>();
+
+  // All shapes are dynamic.
+  SmallVector<int64_t> outShape;
+  outShape.resize(3, -1);
+
+  if (lhsTy.hasRank()) {
+    outShape[0] = lhsTy.getDimSize(0);
+    outShape[1] = lhsTy.getDimSize(1);
+  }
+
+  if (rhsTy.hasRank()) {
+    outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
+    outShape[2] = rhsTy.getDimSize(2);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult tosa::PadOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapedType paddingTy = operands[1].getType().cast<ShapedType>();
+  SmallVector<int64_t> outputShape;
+
+  // If both inputs have unknown shape, we cannot determine the shape of the
+  // output.
+  if (!inputTy.hasRank() && !paddingTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // If the input rank is unknown we can infer the output rank using the
+  // padding shape's first dim.
+  if (!inputTy.hasRank()) {
+    if (paddingTy.isDynamicDim(0)) {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+      return success();
+    }
+
+    outputShape.resize(paddingTy.getDimSize(0), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  DenseIntElementsAttr paddings;
+  // If the paddings value is not a constant, all dimensions must be dynamic.
+  if (!matchPattern(operands[1], m_Constant(&paddings))) {
+    outputShape.resize(inputTy.getRank(), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  SmallVector<int64_t> paddingValues;
+  for (auto val : paddings) {
+    paddingValues.push_back(val.getSExtValue());
+  }
+
+  outputShape.reserve(inputTy.getRank());
+  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+    if (inputTy.isDynamicDim(i)) {
+      outputShape.push_back(-1);
+      continue;
+    }
+
+    outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] +
+                          paddingValues[i * 2 + 1]);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult tosa::SliceOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto sizes = attributes.get("size").cast<ArrayAttr>().getValue();
+  SmallVector<int64_t> outputShape;
+  outputShape.reserve(sizes.size());
+  for (auto val : sizes) {
+    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult tosa::TableOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+
+  if (!inputTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  inferredReturnShapes.push_back(inputTy.getShape());
+  return success();
+}
+
+LogicalResult tosa::TileOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto multiples = attributes.get("multiples").cast<ArrayAttr>().getValue();
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  SmallVector<int64_t> outputShape;
+  if (!inputTy.hasRank()) {
+    outputShape.resize(multiples.size(), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  // We need the multiple values to determine the output shape.
+  SmallVector<int64_t> multipleValues;
+  multipleValues.reserve(multiples.size());
+  for (auto val : multiples) {
+    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+
+  // Any non dynamic dimension can be multiplied to a known size.
+  outputShape.reserve(multiples.size());
+  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+    int dim = inputTy.getDimSize(i);
+    if (dim != -1)
+      dim *= multipleValues[i];
+    outputShape.push_back(dim);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -339,6 +594,163 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapedType permsTy = operands[1].getType().cast<ShapedType>();
+
+  // If both the input rank and the permutation length are unknown, the output
+  // rank is unknown.
+  if (!inputTy.hasRank() && (!permsTy.hasRank() || permsTy.isDynamicDim(0))) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // Without the input dims we cannot determine the output dim sizes but we
+  // can determine the output rank.
+  SmallVector<int64_t> outputShape;
+  if (!inputTy.hasRank()) {
+    outputShape.resize(permsTy.getDimSize(0), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  // Rank-0 means no permutations matter.
+  if (inputTy.getRank() == 0) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  // Check whether the input dimensions are all the same.
+  bool allTheSame = true;
+  for (int i = 1, s = inputTy.getRank(); i < s; i++) {
+    if (inputTy.getDimSize(0) != inputTy.getDimSize(i)) {
+      allTheSame = false;
+      break;
+    }
+  }
+
+  // If all of the input dimensions are the same we don't care about the
+  // permutation.
+  if (allTheSame) {
+    outputShape.resize(inputTy.getRank(), inputTy.getDimSize(0));
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  DenseIntElementsAttr perms;
+  outputShape.resize(inputTy.getRank(), -1);
+  // If the permutations are a constant we can directly determine the output
+  // shape.
+  if (matchPattern(operands[1], m_Constant(&perms))) {
+    llvm::SmallVector<int64_t> permValues;
+    for (auto val : perms) {
+      permValues.push_back(val.getSExtValue());
+    }
+
+    outputShape.reserve(inputTy.getRank());
+    for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+      outputShape[i] = inputTy.getDimSize(permValues[i]);
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult tosa::GatherOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outputShape;
+  outputShape.resize(3, -1);
+
+  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = ty.getDimSize(0);
+    outputShape[2] = ty.getDimSize(2);
+  }
+
+  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
+    if (outputShape[0] == -1)
+      outputShape[0] = ty.getDimSize(0);
+    if (outputShape[1] == -1)
+      outputShape[1] = ty.getDimSize(1);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outputShape;
+  outputShape.resize(3, -1);
+
+  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = ty.getDimSize(0);
+    outputShape[1] = ty.getDimSize(1);
+    outputShape[2] = ty.getDimSize(2);
+  }
+
+  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
+    if (outputShape[0] == -1)
+      outputShape[0] = ty.getDimSize(0);
+  }
+
+  if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
+    if (outputShape[0] == -1)
+      outputShape[0] = ty.getDimSize(0);
+    if (outputShape[2] == -1)
+      outputShape[2] = ty.getDimSize(2);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+static LogicalResult ReduceInferReturnTypes(
+    Value operand, IntegerAttr axis,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto operandTy = operand.getType().cast<ShapedType>();
+  if (!operandTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  int64_t axisVal = axis.getValue().getSExtValue();
+  SmallVector<int64_t> outputShape;
+  outputShape.reserve(operandTy.getRank());
+  for (auto dim : operandTy.getShape()) {
+    outputShape.push_back(dim);
+  }
+
+  outputShape[axisVal] = 1;
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+#define REDUCE_SHAPE_INFER(OP)                                                 \
+  LogicalResult OP::inferReturnTypeComponents(                                 \
+      MLIRContext *context, ::llvm::Optional<Location> location,               \
+      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
+    return ReduceInferReturnTypes(operands[0],                                 \
+                                  attributes.get("axis").cast<IntegerAttr>(),  \
+                                  inferredReturnShapes);                       \
+  }
+
+REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
+REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
+REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
+REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
+REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
+REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
+#undef REDUCE_SHAPE_INFER
+
 static LogicalResult resolveBroadcastShape(ValueRange operands,
                                            SmallVector<int64_t> &outShape) {
   int64_t outRank = 0;

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
new file mode 100644
index 0000000000000..bfbbe07d42fde
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -0,0 +1,662 @@
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+
+// CHECK-LABEL: @test_return
+func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multiple
+func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_f32
+func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
+  %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
+
+  // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i32
+func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
+  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i1
+func @test_unary_i1(%arg0 : tensor<4xi1>) -> () {
+  // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1>
+  %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_scalar_f32
+func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_broadcast_f32
+func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_i32
+func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK:  "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_i1
+func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
+  // CHECK: "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_select_i32
+func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : tensor<4xi32>) -> () {
+  // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<4xi32>
+  %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<*xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_argmax
+func @test_static_argmax(%arg0 : tensor<2x3xi32>) -> () {
+  // CHECK: "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor<3xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor<?xi32>
+
+  // CHECK: "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor<2xi32>
+  %1 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_argmax
+func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () {
+  // CHECK: "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
+
+  // CHECK: "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor<2xi32>
+  %1 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_fully_connected
+func @test_static_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor<5x4xf32>, %arg2 : tensor<5xf32>) -> () {
+  // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor<3x5xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_input_fully_connected
+func @test_static_input_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> () {
+  // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<?x?xf32>, tensor<?xf32>) -> tensor<3x?xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_weight_fully_connected
+func @test_static_weight_fully_connected(%arg0 : tensor<?x?xf32>, %arg1 : tensor<5x4xf32>, %arg2 : tensor<?xf32>) -> () {
+  // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<5x4xf32>, tensor<?xf32>) -> tensor<?x5xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<5x4xf32>, tensor<?xf32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_bias_fully_connected
+func @test_static_bias_fully_connected(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<5xf32>) -> () {
+  // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x5xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_out_fully_connected
+func @test_static_out_fully_connected(%arg0 : tensor<3x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<5xf32>) -> () {
+  // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<3x5xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_matmul
+func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
+  // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_lhs_matmul
+func @test_dynamic_lhs_matmul(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
+  // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<2x?x5xi32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_rhs_matmul
+func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<?x?x?xi32>) -> () {
+  // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<2x3x?xi32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_mixed_matmul
+func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?x?x5xi32>) -> () {
+  // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x3x5xi32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_table_static
+func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
+  // CHECK:"tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<?x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_table_dynamic
+func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>) -> () {
+  // CHECK:"tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) -> tensor<4x?xi16>
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) -> tensor<?x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_reshape
+func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>)  -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>)  -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>)  -> tensor<?x?xi32>
+
+  return
+}
+// -----
+
+// CHECK-LABEL: @test_dynamic_reshape
+func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
+  // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>)  -> tensor<?xi32>
+
+  // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor<?xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>)  -> tensor<?xi32>
+
+  // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>)  -> tensor<?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_reduce_binary
+func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x1x?x?xi1>
+  %1 = "tosa.reduce_all"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x1x?xi1>
+  %2 = "tosa.reduce_all"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x?x1xi1>
+  %3 = "tosa.reduce_all"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %4 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_reduce_float
+func @test_reduce_float(%arg0 : tensor<2x3x?x?xf32>) -> () {
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<1x3x?x?xf32>
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x1x?x?xf32>
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x1x?xf32>
+  %2 = "tosa.reduce_sum"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %3 = "tosa.reduce_sum"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %4 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_min"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %5 = "tosa.reduce_min"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %6 = "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat
+func @test_concat(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_dynamic
+func @test_concat_dynamic(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x?xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<3x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_dynamic_axis
+func @test_concat_dynamic_axis(%arg0 : tensor<?x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_axis_1
+func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x3xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_failure
+func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_padding_no_const
+func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
+  // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.pad"(%arg0, %arg1)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL:@test_padding_dynamic_input
+func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x?xf32>, tensor<2x2xi32>)  -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_padding_simple
+func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice
+func @test_slice(%arg0 : tensor<?xi32>) -> () {
+  // CHECK: "tosa.slice"(%arg0) {size = [2], start = [1]} : (tensor<?xi32>) -> tensor<2xi32>
+  %0 = "tosa.slice"(%arg0) { size = [2], start = [1] } : (tensor<?xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_tile
+func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
+  // CHECK: "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_same
+func @test_transpose_same(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<3xi32>) -> () {
+  // CHECK: "tosa.transpose"(%arg0, %arg1) : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
+  %0 = "tosa.transpose"(%arg0, %arg1) : (tensor<4x4x4xi32>, tensor<3xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_perm_unknown
+func @test_transpose_perm_unknown(%arg0 : tensor<4x4x5xi32>, %arg1 : tensor<3xi32>) -> () {
+  // CHECK: "tosa.transpose"(%arg0, %arg1) : (tensor<4x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+  %0 = "tosa.transpose"(%arg0, %arg1) : (tensor<4x4x5xi32>, tensor<3xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_static
+func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
+  %0 = constant dense<[2, 1, 0]> : tensor<3xi32>
+  // CHECK: "tosa.transpose"(%arg0, %cst) : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<3x4x5xi32>, tensor<3xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @gather_static
+func @gather_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>) {
+  // CHECK: "tosa.gather"(%arg0, %arg1) : (tensor<3x4x5xi32>, tensor<3x6xi32>) -> tensor<3x6x5xi32>
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<3x4x5xi32>, tensor<3x6xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @gather_dynamic_values
+func @gather_dynamic_values(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<3x6xi32>) {
+  // CHECK: "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<3x6xi32>) -> tensor<3x6x?xi32>
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<3x6xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @gather_dynamic_indices
+func @gather_dynamic_indices(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<?x?xi32>) {
+  // CHECK: "tosa.gather"(%arg0, %arg1) : (tensor<3x4x5xi32>, tensor<?x?xi32>) -> tensor<3x?x5xi32>
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<3x4x5xi32>, tensor<?x?xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @gather_minimum_info
+func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32>) {
+  // CHECK: "tosa.gather"(%arg0, %arg1) : (tensor<3x?x5xi32>, tensor<?x6xi32>) -> tensor<3x6x5xi32>
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<3x?x5xi32>, tensor<?x6xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_static
+func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+  // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_static_values
+func @scatter_static_values(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<?x?xi32>, %arg2 : tensor<?x?x?xi32>) {
+  // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<3x4x5xi32>, tensor<?x?xi32>, tensor<?x?x?xi32>) -> tensor<3x4x5xi32>
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<3x4x5xi32>, tensor<?x?xi32>, tensor<?x?x?xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_static_indices
+func @scatter_static_indices(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<?x?x?xi32>) {
+  // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xi32>, tensor<3x6xi32>, tensor<?x?x?xi32>) -> tensor<3x?x?xi32>
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xi32>, tensor<3x6xi32>, tensor<?x?x?xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_static_input
+func @scatter_static_input(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<?x?xi32>, %arg2 : tensor<3x6x5xi32>) {
+  // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xi32>, tensor<?x?xi32>, tensor<3x6x5xi32>) -> tensor<3x?x5xi32>
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xi32>, tensor<?x?xi32>, tensor<3x6x5xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_minimum_static
+func @scatter_minimum_static(%arg0 : tensor<?x4x?xi32>, %arg1 : tensor<3x?xi32>, %arg2 : tensor<?x?x5xi32>) {
+  // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x4x?xi32>, tensor<3x?xi32>, tensor<?x?x5xi32>) -> tensor<3x4x5xi32>
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x4x?xi32>, tensor<3x?xi32>, tensor<?x?x5xi32>)  -> (tensor<?x?x?xi32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir
deleted file mode 100644
index e73c79cb3ceef..0000000000000
--- a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir
+++ /dev/null
@@ -1,278 +0,0 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
-
-// CHECK-LABEL: @test_return
-func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
-  // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
-  %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_multiple
-func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
-  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
-
-  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_unary_f32
-func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
-  // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
-  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
-  %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
-  %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
-
-  // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_unary_i32
-func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
-  // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
-  %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
-  %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
-  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
-  %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
-  %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
-  %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
-  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_unary_i1
-func @test_unary_i1(%arg0 : tensor<4xi1>) -> () {
-  // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1>
-  %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_scalar_f32
-func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_broadcast_f32
-func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_i32
-func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK:  "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_i1
-func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
-  // CHECK "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
-  %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
-
-  // CHECK "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
-  %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
-
-  // CHECK "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<*4i1>
-  %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_select_i32
-func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : tensor<4xi32>) -> () {
-  // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<4xi32>
-  %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<*xi32>
-
-  return
-}
-
-// -----
-
-func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
-  // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>)  -> tensor<?xi32>
-
-  // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32>
-  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>)  -> tensor<?xi32>
-
-  // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32>
-  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>)  -> tensor<?x?xi32>
-
-  return
-}
-// -----
-
-func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
-  // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>)  -> tensor<?xi32>
-
-  // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor<?xi32>
-  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>)  -> tensor<?xi32>
-
-  // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32>
-  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>)  -> tensor<?x?xi32>
-
-  return
-}
-


        


More information about the Mlir-commits mailing list