[Mlir-commits] [mlir] fc0d4d5 - [mlir][tosa] Change axis to I32Attr

Robert Suderman <llvmlistbot@llvm.org>
Tue Aug 15 12:34:31 PDT 2023


Author: Tai Ly
Date: 2023-08-15T19:04:25Z
New Revision: fc0d4d53ced190bdafb78564f46e67fa97e54785

URL: https://github.com/llvm/llvm-project/commit/fc0d4d53ced190bdafb78564f46e67fa97e54785
DIFF: https://github.com/llvm/llvm-project/commit/fc0d4d53ced190bdafb78564f46e67fa97e54785.diff

LOG: [mlir][tosa] Change axis to I32Attr

This patch changes the type of the axis attribute from I64Attr to I32Attr
to match the TOSA spec. The affected ops are tosa.argmax, tosa.concat,
tosa.reverse, and the tosa.reduce_* ops.
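
As an illustration (a sketch drawn from the tests updated below, not new
functionality), the assembly format for the affected ops now carries an
i32 axis:

  // Previously printed as {axis = 1 : i64}:
  %0 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32>

Downstream code that built the attribute with
rewriter.getI64IntegerAttr(axis) needs to switch to
rewriter.getI32IntegerAttr(axis), as done for tosa.reverse in
TosaDecomposeTransposeConv.cpp below.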

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I88416f0b102f8ac2191da63a3fbce25759f19613

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D157424

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/constant-op-fold.mlir
    mlir/test/Dialect/Tosa/constrained_shapes.mlir
    mlir/test/Dialect/Tosa/fold_concats.mlir
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/level_check.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 174979e55c7215..83e9d760aa896d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 
   let arguments = (ins
     Tosa_Tensor: $input,
-    I64Attr: $axis
+    I32Attr: $axis
   );
 
   let results = (outs
@@ -1260,7 +1260,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1288,7 +1288,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1316,7 +1316,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1344,7 +1344,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1372,7 +1372,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1400,7 +1400,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1434,7 +1434,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let arguments = (ins
     Variadic<Tosa_Tensor>:$input1,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs
@@ -1544,7 +1544,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$axis
+    I32Attr:$axis
   );
 
   let results = (outs

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 2339fb7fde3dc1..8d937217d70655 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -112,9 +112,9 @@ class TransposeConvNonStridedConverter
     convPad[3] = kernelWidth - 1 + pad[3];
 
     auto reverse1 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
     auto reverse2 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
 
     Value conv2d;
     if (op.getQuantizationInfo()) {
@@ -236,10 +236,10 @@ class TransposeConvStridedConverter
 
     weight = createOpAndInfer<tosa::ReverseOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getI64IntegerAttr(1));
+        /* axis = */ rewriter.getI32IntegerAttr(1));
     weight = createOpAndInfer<tosa::ReverseOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getI64IntegerAttr(2));
+        /* axis = */ rewriter.getI32IntegerAttr(2));
 
     // We need to pad the input far enough that we can pull all values.
     llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 115836280628ff..8029eae47ba4a7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -903,7 +903,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
   // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
   // CHECK: [[CST0:%.+]] = arith.constant 0.0
@@ -913,25 +913,25 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
   // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
-  %1 = tosa.reduce_sum %arg0 {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
+  %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: arith.constant 1.0
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %2 = tosa.reduce_prod %arg0 {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  %2 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: arith.constant 3.40282347E+38 : f32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.minf
-  %3 = tosa.reduce_min %arg0 {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: arith.constant -3.40282347E+38 : f32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.maxf
-  %4 = tosa.reduce_max %arg0 {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
   return
 }
 
@@ -953,7 +953,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
-  %0 = tosa.reduce_sum %arg0 {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
 
@@ -973,7 +973,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<?xf32>) -> tensor<1xf32>
   return
 }
 
@@ -995,7 +995,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
-  %0 = tosa.reduce_prod %arg0 {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
+  %0 = tosa.reduce_prod %arg0 {axis = 2 : i32} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
 
@@ -1017,7 +1017,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK:   %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[MAX]] : f32
   // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
-  %0 = tosa.reduce_max %arg0 {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return
 }
 
@@ -1038,7 +1038,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
   // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[CST0:%.+]] = arith.constant 0
@@ -1048,27 +1048,27 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
   // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
-  %1 = tosa.reduce_sum %arg0 {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
+  %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: arith.constant 1
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.muli
-  %2 = tosa.reduce_prod %arg0 {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
+  %2 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: arith.constant 2147483647 : i32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.cmpi slt
   // CHECK: select
-  %3 = tosa.reduce_min %arg0 {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: arith.constant -2147483648 : i32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: arith.cmpi sgt
   // CHECK: select
-  %4 = tosa.reduce_max %arg0 {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
   return
 }
 
@@ -1088,13 +1088,13 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK:   [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
   // CHECK:   linalg.yield [[RES]] : i1
   // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
-  %0 = tosa.reduce_all %arg0 {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: arith.constant false
   // CHECK: linalg.fill
   // CHECK: linalg.generic
   // CHECK: or
-  %1 = tosa.reduce_any %arg0 {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
+  %1 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   return
 }
@@ -1294,7 +1294,7 @@ func.func @reverse(%arg0: tensor<5x4xi32>) -> () {
   // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I0]]
   // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32>
   // CHECK:   linalg.yield %[[EXTRACT]]
-  %0 = tosa.reverse %arg0 {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<5x4xi32>
 
   // CHECK: %[[C1:.+]] = arith.constant 1
   // CHECK: %[[RDIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
@@ -1307,7 +1307,7 @@ func.func @reverse(%arg0: tensor<5x4xi32>) -> () {
   // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I1]]
   // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32>
   // CHECK:   linalg.yield %[[EXTRACT]]
-  %1 = tosa.reverse %arg0 {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+  %1 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x4xi32>
   return
 }
 
@@ -1330,7 +1330,7 @@ func.func @reverse_dyn(%arg0: tensor<?xi32>) -> () {
   // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I0]]
   // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor<?xi32>
   // CHECK:   linalg.yield %[[EXTRACT]]
-  %0 = tosa.reverse %arg0 {axis = 0 : i64} : (tensor<?xi32>) -> tensor<?xi32>
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<?xi32>) -> tensor<?xi32>
   return
 }
 
@@ -1429,7 +1429,7 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK:   [[SELECT_VAL:%.+]] = arith.select [[CMP]], %[[ARG1]], %[[ARG3]]
   // CHECK:   [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %[[ARG2]]
   // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
-  %0 = tosa.argmax %arg0 { axis = 0 : i64} : (tensor<3x2xi32>)  -> tensor<2xi32>
+  %0 = tosa.argmax %arg0 { axis = 0 : i32} : (tensor<3x2xi32>)  -> tensor<2xi32>
 
   // CHECK: [[IDX_INIT:%.+]] = tensor.empty()
   // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32
@@ -1445,7 +1445,7 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK:   [[SELECT_VAL:%.+]] = arith.select [[CMP]], %[[ARG1]], %[[ARG3]]
   // CHECK:   [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %[[ARG2]]
   // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
-  %1 = tosa.argmax %arg0 { axis = 1 : i64} : (tensor<3x2xi32>)  -> tensor<3xi32>
+  %1 = tosa.argmax %arg0 { axis = 1 : i32} : (tensor<3x2xi32>)  -> tensor<3xi32>
 
   // CHECK: arith.constant -3.40282347E+38 : f32
   // CHECK: linalg.index
@@ -1454,7 +1454,7 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK: select
   // CHECK: select
   // CHECK: linalg.yield
-  %2 = tosa.argmax %arg1 { axis = 0 : i64} : (tensor<6xf32>)  -> tensor<i32>
+  %2 = tosa.argmax %arg1 { axis = 0 : i32} : (tensor<6xf32>)  -> tensor<i32>
 
   return
 }
@@ -1481,7 +1481,7 @@ func.func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
   // CHECK:   %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG3]]
   // CHECK:   %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %[[ARG2]]
   // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
-  %0 = tosa.argmax %arg0 { axis = 0 : i64} : (tensor<3x?xi32>)  -> tensor<?xi32>
+  %0 = tosa.argmax %arg0 { axis = 0 : i32} : (tensor<3x?xi32>)  -> tensor<?xi32>
   return
 }
 
@@ -1504,7 +1504,7 @@ func.func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
   // CHECK:   %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG3]]
   // CHECK:   %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %[[ARG2]]
   // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
-  %0 = tosa.argmax %arg0 { axis = 1 : i64} : (tensor<3x?xi32>)  -> tensor<3xi32>
+  %0 = tosa.argmax %arg0 { axis = 1 : i32} : (tensor<3x?xi32>)  -> tensor<3xi32>
   return
 }
 

diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 363648afb5180d..daaa68a7260b71 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -205,12 +205,12 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
   // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
   // CHECK-DAG: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i32} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
 
   // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
   // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
   // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
-  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
+  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i32} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
   return
 }
 
@@ -230,7 +230,7 @@ func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -
   // CHECK-DAG: %[[IDX1_2:.+]] = arith.constant 1 : index
   // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[IDX1_2]] : tensor<6x?xf32>
   // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[DIM2]]] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i32} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
   return
 }
 
@@ -253,7 +253,7 @@ func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> ()
   // CHECK-DAG: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[IDX0_2]] : tensor<?x3xf32>
   // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
 
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i32} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
   return
 }
 
@@ -284,7 +284,7 @@ func.func @concat_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>,
 
   // CHECK: return
 
-  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x1xf32>
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x1xf32>
   return
 }
 
@@ -310,6 +310,6 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
   // CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][0, 2] [%[[DIM2_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
   // CHECK: return
 
-  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
   return
 }

diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index ced4249ef7b47a..5ed5a383c6f6fe 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
@@ -113,7 +113,7 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
 // CHECK-LABEL: @concat_fold
 func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.concat %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.concat %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
@@ -121,7 +121,7 @@ func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
   // CHECK: return %[[VAR0]]
-  %0 = tosa.concat %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x?xf32>
+  %0 = tosa.concat %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -277,84 +277,84 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
 // CHECK-LABEL: @reduce_all_fold
 func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_all %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_all %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_all_nofold
 func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_all
-  %0 = tosa.reduce_all %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_any_fold
 func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_any %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_any %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_any_nofold
 func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_any
-  %0 = tosa.reduce_any %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_any %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_max_fold
 func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_max %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_max %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_max_nofold
 func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_max
-  %0 = tosa.reduce_max %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_max %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_min_fold
 func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_min %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_min %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_min_nofold
 func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_min
-  %0 = tosa.reduce_min %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_prod_fold
 func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_prod %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_prod %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_prod_nofold
 func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_prod
-  %0 = tosa.reduce_prod %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_prod %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_sum_fold
 func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
-  %0 = tosa.reduce_sum %arg0 {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 1 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
 // CHECK-LABEL: @reduce_sum_nofold
 func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_sum
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %0 : tensor<?x1xf32>
 }
 
@@ -504,7 +504,7 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
 func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
-  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
   %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
   %2 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
   return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
@@ -516,7 +516,7 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
 func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
-  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
   %1 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
   %2 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
   return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
@@ -526,12 +526,12 @@ func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %a
 
 // CHECK-LABEL: @canonicalize_cross_concat_inputs
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
 // CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
 // CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
 func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
-  %0 = tosa.concat %arg0, %arg1 {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
   %1 = tosa.slice %0 {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
   %2 = tosa.slice %0 {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
   return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
@@ -545,7 +545,7 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
 // CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
 // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
-  %0 = tosa.concat %arg0, %arg1 {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
   %1 = tosa.slice %0 {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
   %2 = tosa.slice %0 {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
   return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>

diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 819aa0010a804a..39132c9744e8f7 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -558,7 +558,7 @@ func.func @cast_int_to_int_sign() -> tensor<i32> {
 func.func @reverse_splat() -> tensor<10xi32> {
   // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{value = dense<42> : tensor<10xi32>}
   %splat = "tosa.const"() {value = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
-  %reverse = tosa.reverse %splat { axis = 0 : i64 } : (tensor<10xi32>) -> tensor<10xi32>
+  %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32>
   // CHECK: return %[[SPLAT]]
   return %reverse : tensor<10xi32>
 }
@@ -567,9 +567,9 @@ func.func @reverse_splat() -> tensor<10xi32> {
 
 // CHECK-LABEL: @reverse_length_one
 func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
-  %nofold = tosa.reverse %arg0 { axis = 0 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
-  %fold = tosa.reverse %arg0 { axis = 1 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
-  // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i64}
+  %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32}
   // CHECK: return %[[NOFOLD]], %arg0
   return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
 }

diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 6f8d8bdedae5cc..9acb024cf78d00 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<?xf32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }

diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
index 8bdfe066f075da..ec54f27346c8bd 100644
--- a/mlir/test/Dialect/Tosa/fold_concats.mlir
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -1,28 +1,28 @@
 // RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
 
 func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
   return %0 : tensor<1x2x7x7xf32>
 }
 
 // CHECK-LABEL:   func.func @single_concat(
 // CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
 // CHECK:           return %[[VAL_1]] : tensor<1x2x7x7xf32>
 // CHECK:         }
 
 // -----
 
 func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %0, %0 {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %0, %0 {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
   return %1 : tensor<2x2x7x7xf32>
 }
 
 // CHECK-LABEL:   func.func @concat_different_axis(
 // CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i64} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
 // CHECK:           return %[[VAL_2]] : tensor<2x2x7x7xf32>
 // CHECK:         }
 
@@ -30,15 +30,15 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf
 
 func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
   %tmp = tensor.empty() : tensor<1x1x7x7xf32>
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %tmp, %0, %tmp {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
   return %1 : tensor<1x4x7x7xf32>
 }
 
 // CHECK-LABEL:   func.func @fold_concats(
 // CHECK-SAME:                            %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
 // CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
 // CHECK:         }
 
@@ -46,48 +46,48 @@ func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
 
 func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
   %tmp = tensor.empty() : tensor<1x1x7x7xf32>
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %tmp, %0, %tmp {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-  %2 = tosa.concat %1, %1 {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %2 = tosa.concat %1, %1 {axis = 1 : i32} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
   return %2 : tensor<1x8x7x7xf32>
 }
 
 // CHECK-LABEL:   func.func @nested_fold(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
 // CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x8x7x7xf32>
 // CHECK:         }
 
 // -----
 
 func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %arg1, %arg1 {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %2 = tosa.concat %0, %1 {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
   return %2 : tensor<1x4x7x7xf32>
 }
 
 // CHECK-LABEL:   func.func @wide_fold(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]] {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
 // CHECK:         }
 
 // -----
 
 func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
-  %1 = tosa.concat %arg1, %arg1 {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-  %2 = tosa.concat %0, %1 {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+  %1 = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
   return %2 : tensor<1x4x8x8xf32>
 }
 
 // CHECK-LABEL:   func.func @partially_foldable(
 // CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
 // CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i64} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-// CHECK:           %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i64} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK:           %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1x4x8x8xf32>
 // CHECK:         }

diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index d05d86eeaceac8..7c58bb10b9c5ed 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -40,7 +40,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -49,7 +49,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
 
@@ -103,7 +103,7 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
 func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
   return
 }
 
@@ -112,7 +112,7 @@ func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
 func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
-  %0 = tosa.reduce_max %arg0 {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
+  %0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
   return
 }
 
@@ -121,7 +121,7 @@ func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
 func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
-  %0 = tosa.reduce_min %arg0 {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
+  %0 = tosa.reduce_min %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
   return
 }
 
@@ -130,7 +130,7 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
 func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
-  %0 = tosa.reduce_prod %arg0 {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+  %0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
   return
 }
 
@@ -190,7 +190,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
 
 func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
-    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} 
+    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
@@ -198,8 +198,8 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
 // -----
 
 func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}  
-    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} 
+  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }

diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 40160a9452f5b0..69b22888246bb9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -3,7 +3,7 @@
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
   // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.argmax"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
   return %0 : tensor<1x1x1x1x29x4xf32>
 }
 
@@ -11,7 +11,7 @@ func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x
 
 func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
   // expected-error@+1 {{'tosa.reduce_all' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
+  %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
   return %0 : tensor<1x1x1x1x1x21x3xi1>
 }
 
@@ -19,7 +19,7 @@ func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x
 
 func.func @test_reduce_any(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
   // expected-error@+1 {{'tosa.reduce_any' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
+  %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
   return %0 : tensor<1x1x1x1x13x21x3xi1>
 }
 
@@ -27,7 +27,7 @@ func.func @test_reduce_any(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x
 
 func.func @test_reduce_max(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.reduce_max' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 
@@ -35,7 +35,7 @@ func.func @test_reduce_max(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1
 
 func.func @test_reduce_min(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.reduce_min' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 
@@ -43,7 +43,7 @@ func.func @test_reduce_min(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1
 
 func.func @test_reduce_prod(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.reduce_prod' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 
@@ -51,7 +51,7 @@ func.func @test_reduce_prod(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x
 
 func.func @test_reduce_sum(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.reduce_sum' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 
@@ -59,7 +59,7 @@ func.func @test_reduce_sum(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1
 
 func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> {
   // expected-error@+1 {{'tosa.concat' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x1x1x13x21x3x8xf32>, tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i32} : (tensor<1x1x1x13x21x3x8xf32>, tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32>
   return %0 : tensor<1x1x1x26x21x3x8xf32>
 }
 
@@ -75,7 +75,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf3
 
 func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.reverse' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 

diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 0adf4bdcd4354a..754843969ef8ef 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -5,7 +5,7 @@
 // -----
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
-  %0 = tosa.argmax %arg0 {axis = 1 : i64} : (tensor<14x19xf32>) -> tensor<14xi32>
+  %0 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32>
   return %0 : tensor<14xi32>
 }
 
@@ -365,7 +365,7 @@ func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf
 // -----
 // CHECK-LABEL: reduce_all
 func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
-  %0 = tosa.reduce_all %arg0 {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
   return %1 : tensor<21x3xi1>
 }
@@ -373,7 +373,7 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
 // -----
 // CHECK-LABEL: reduce_any
 func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
-  %0 = tosa.reduce_any %arg0 {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  %0 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
   return %1 : tensor<21x3xi1>
 }
@@ -381,7 +381,7 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
 // -----
 // CHECK-LABEL: reduce_max
 func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
-  %0 = tosa.reduce_max %arg0 {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
   return %1 : tensor<21x3xf32>
 }
@@ -389,7 +389,7 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 // -----
 // CHECK-LABEL: reduce_min
 func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
-  %0 = tosa.reduce_min %arg0 {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
   return %1 : tensor<21x3xf32>
 }
@@ -397,7 +397,7 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 // -----
 // CHECK-LABEL: reduce_product
 func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
-  %0 = tosa.reduce_prod %arg0 {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %0 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
   return %1 : tensor<21x3xf32>
 }
@@ -405,7 +405,7 @@ func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 // -----
 // CHECK-LABEL: reduce_sum
 func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
   %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
   return %1 : tensor<21x3xf32>
 }
@@ -413,7 +413,7 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 // -----
 // CHECK-LABEL: concat
 func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
   return %0 : tensor<26x21x3xf32>
 }
 
@@ -442,7 +442,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
 // -----
 // CHECK-LABEL: reverse
 func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.reverse %arg0 {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
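
The test updates above are mechanical: only the attribute's storage type
changes, not the axis semantics. A minimal before/after sketch (assuming,
as the I32Attr constraint implies, that the verifier now rejects the old
i64 form):

  // Old I64Attr form, expected to fail verification after this patch:
  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
  // New I32Attr form matching the TOSA spec:
  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>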
 

diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 5ed8e295b33c89..5f9fe68c578925 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: @transpose_conv2d
 func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x18x19x5xf32> {
-  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i64}
-  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i64}
+  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
   // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
   // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
@@ -15,8 +15,8 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
 // CHECK-LABEL: @transpose_conv2d_quantized
 
 func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
-  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i64}
-  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i64}
+  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
   // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
   return %0 : tensor<2x18x19x5xi32>
@@ -26,8 +26,8 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 
 // CHECK-LABEL: @transpose_conv2d_quantized_padded
 func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
-  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i64}
-  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i64}
+  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i32}
+  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
   // CHECK: tosa.conv2d %arg0, %1, %arg2
   // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
   // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
@@ -50,8 +50,8 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
   // CHECK-DAG: %[[TRANS:.+]]  = tosa.transpose %[[RESW1]], %[[TRANSV]]
   // CHECK-DAG: %[[RESW2:.+]]  = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
-  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i64}
-  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i64}
+  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i32}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
   // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
@@ -83,8 +83,8 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
   // CHECK-DAG: %[[TRANS:.+]]  = tosa.transpose %[[RESW1]], %[[TRANSV]]
   // CHECK-DAG: %[[RESW2:.+]]  = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
-  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i64}
-  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i64}
+  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i32}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
   // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
@@ -121,7 +121,7 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
   // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
-  // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i64}
+  // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
   // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
   // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
   // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
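
The reverse ops checked in this file come from how the decomposition
implements transpose convolution: the kernel is flipped along both spatial
axes and handed to an ordinary conv2d with adjusted padding, so those
flipped axes are exactly where the new i32 attribute surfaces. A condensed
sketch of the unstrided case from the first test (SSA names hypothetical):

  %flip_h = tosa.reverse %weight {axis = 1 : i32} : (tensor<5x3x6x3xf32>) -> tensor<5x3x6x3xf32>
  %flip_w = tosa.reverse %flip_h {axis = 2 : i32} : (tensor<5x3x6x3xf32>) -> tensor<5x3x6x3xf32>
  %conv = tosa.conv2d %input, %flip_w, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>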

diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index ee668c8b21c4ce..cb96b2a8a0d193 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -51,8 +51,8 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
   // CHECK: tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<4xf32>
   %7 = tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.reverse %arg0 {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
-  %8 = tosa.reverse %arg0 { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
+  // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<4xf32>
+  %8 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xf32>) -> tensor<?xf32>
 
   // CHECK: tosa.rsqrt %arg0 : (tensor<4xf32>) -> tensor<4xf32>
   %9 = tosa.rsqrt %arg0 : (tensor<4xf32>) -> tensor<*xf32>
@@ -90,8 +90,8 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   // CHECK: tosa.negate %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %4 = tosa.negate %arg0 : (tensor<4xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.reverse %arg0 {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
-  %5 = tosa.reverse %arg0 { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
+  // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
+  %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
 
   // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
   %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i32: 14, 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>) -> tensor<*xi16>
@@ -248,11 +248,11 @@ func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : te
 
 // CHECK-LABEL: @test_static_argmax
 func.func @test_static_argmax(%arg0 : tensor<2x3xi32>) -> () {
-  // CHECK: tosa.argmax %arg0 {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor<3xi32>
-  %0 = tosa.argmax %arg0 {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor<?xi32>
+  // CHECK: tosa.argmax %arg0 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<3xi32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<?xi32>
 
-  // CHECK: tosa.argmax %arg0 {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor<2xi32>
-  %1 = tosa.argmax %arg0 {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor<?xi32>
+  // CHECK: tosa.argmax %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2xi32>
+  %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<?xi32>
   return
 }
 
@@ -260,11 +260,11 @@ func.func @test_static_argmax(%arg0 : tensor<2x3xi32>) -> () {
 
 // CHECK-LABEL: @test_dynamic_argmax
 func.func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () {
-  // CHECK: tosa.argmax %arg0 {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
-  %0 = tosa.argmax %arg0 {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
+  // CHECK: tosa.argmax %arg0 {axis = 0 : i32} : (tensor<2x?xi32>) -> tensor<?xi32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<2x?xi32>) -> tensor<?xi32>
 
-  // CHECK: tosa.argmax %arg0 {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor<2xi32>
-  %1 = tosa.argmax %arg0 {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor<?xi32>
+  // CHECK: tosa.argmax %arg0 {axis = 1 : i32} : (tensor<2x?xi32>) -> tensor<2xi32>
+  %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<2x?xi32>) -> tensor<?xi32>
   return
 }
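
Note that argmax, unlike the reduce ops below, removes the reduced
dimension outright rather than keeping it at size 1, so the inferred
result drops a rank. From the static test above:

  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<?xi32>
  // shape inference refines the result to tensor<3xi32>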
 
@@ -406,20 +406,20 @@ func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
 
 // CHECK: @test_reduce_binary
 func.func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
-  // CHECK: tosa.reduce_all %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
-  %0 = tosa.reduce_all %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+  // CHECK: tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
 
-  // CHECK: tosa.reduce_all %arg0 {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x1x?x?xi1>
-  %1 = tosa.reduce_all %arg0 {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+  // CHECK: tosa.reduce_all %arg0 {axis = 1 : i32} : (tensor<2x3x?x?xi1>) -> tensor<2x1x?x?xi1>
+  %1 = tosa.reduce_all %arg0 {axis = 1 : i32} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
 
-  // CHECK: tosa.reduce_all %arg0 {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x1x?xi1>
-  %2 = tosa.reduce_all %arg0 {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+  // CHECK: tosa.reduce_all %arg0 {axis = 2 : i32} : (tensor<2x3x?x?xi1>) -> tensor<2x3x1x?xi1>
+  %2 = tosa.reduce_all %arg0 {axis = 2 : i32} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
 
-  // CHECK: tosa.reduce_all %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x?x1xi1>
-  %3 = tosa.reduce_all %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+  // CHECK: tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xi1>) -> tensor<2x3x?x1xi1>
+  %3 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
 
-  // CHECK: tosa.reduce_any %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
-  %4 = tosa.reduce_any %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+  // CHECK: tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %4 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
 
   return
 }
@@ -428,26 +428,26 @@ func.func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
 
 // CHECK: @test_reduce_float
 func.func @test_reduce_float(%arg0 : tensor<2x3x?x?xf32>) -> () {
-  // CHECK: tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<1x3x?x?xf32>
-  %0 = tosa.reduce_sum %arg0 {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xf32>) -> tensor<1x3x?x?xf32>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_sum %arg0 {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x1x?x?xf32>
-  %1 = tosa.reduce_sum %arg0 {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x1x?x?xf32>
+  %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_sum %arg0 {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x1x?xf32>
-  %2 = tosa.reduce_sum %arg0 {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_sum %arg0 {axis = 2 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x3x1x?xf32>
+  %2 = tosa.reduce_sum %arg0 {axis = 2 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_sum %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
-  %3 = tosa.reduce_sum %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %3 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_max %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
-  %4 = tosa.reduce_max %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %4 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_min %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
-  %5 = tosa.reduce_min %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %5 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: tosa.reduce_prod %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
-  %6 = tosa.reduce_prod %arg0 {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %6 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
 
   return
 }
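
The rule these reduce tests encode: shape inference pins the reduced axis
to size 1 and propagates every other dimension, dynamic ones included,
which is how a fully dynamic result type gets refined. One instance from
the checks above:

  %2 = tosa.reduce_sum %arg0 {axis = 2 : i32} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
  // refines to tensor<2x3x1x?xf32>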
@@ -456,8 +456,8 @@ func.func @test_reduce_float(%arg0 : tensor<2x3x?x?xf32>) -> () {
 
 // CHECK-LABEL: @test_concat
 func.func @test_concat(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
-  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
 
   return
 }
@@ -466,8 +466,8 @@ func.func @test_concat(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
 
 // CHECK-LABEL: @test_concat_dynamic
 func.func @test_concat_dynamic(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x?xf32>) -> () {
-  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<3x2xf32>
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
+  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<3x2xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
 
   return
 }
@@ -476,8 +476,8 @@ func.func @test_concat_dynamic(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x?xf32>)
 
 // CHECK-LABEL: @test_concat_dynamic_axis
 func.func @test_concat_dynamic_axis(%arg0 : tensor<?x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
-  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i64} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
 
   return
 }
@@ -486,8 +486,8 @@ func.func @test_concat_dynamic_axis(%arg0 : tensor<?x2xf32>, %arg1 : tensor<2x2x
 
 // CHECK-LABEL: @test_concat_axis_1
 func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
-  // CHECK: tosa.concat %arg0, %arg1 {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x3xf32>
-  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x3xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
 
   return
 }
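
For concat the inference rule is the dual: off-axis dimensions must agree
and propagate (a static size in either operand refines a dynamic one, as
@test_concat_dynamic shows), while the axis dimension is the sum of the
operands' sizes there, staying dynamic when any operand is dynamic along
the axis (as in @test_concat_dynamic_axis). A worked instance:

  // 1 + 2 = 3 along axis 0; the trailing dimension 2 carries through:
  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>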
@@ -1165,7 +1165,7 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
 
     // CHECK:      tosa.concat
     // CHECK-SAME: (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
-    %3 = tosa.concat %arg3, %arg3 {axis = 0 : i64} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+    %3 = tosa.concat %arg3, %arg3 {axis = 0 : i32} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
     // CHECK:      tosa.yield
    // CHECK-SAME: tensor<i32>