[Mlir-commits] [mlir] 729f958 - [TOSA] Add SameOperandsAndResultRank to TOSA Ops (#104501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 22 05:21:08 PST 2025


Author: Tai Ly
Date: 2025-01-22T13:21:04Z
New Revision: 729f958c4f7548c2d5be5f024b7254cd3ea25c64

URL: https://github.com/llvm/llvm-project/commit/729f958c4f7548c2d5be5f024b7254cd3ea25c64
DIFF: https://github.com/llvm/llvm-project/commit/729f958c4f7548c2d5be5f024b7254cd3ea25c64.diff

LOG: [TOSA] Add SameOperandsAndResultRank to TOSA Ops (#104501)

[note: this is blocked by:
https://github.com/tensorflow/tensorflow/pull/73891 otherwise tensorflow
may have lit test failures]

This patch adds SameOperandsAndResultRank trait to TOSA operators with
ResultsBroadcastableShape trait. SameOperandsAndResultRank trait
requiring that all operands and results have matching ranks unless the
operand/result is unranked.

This also renders the TosaMakeBroadcastable pass unnecessary - but this
pass is left in for now just in case it is still used in some flows. The
lit test, broadcast.mlir, is removed.

This also adds verify of the SameOperandsAndResultRank trait in the
TosaInferShapes pass to validate inferred shapes.

Signed-off-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/constant_folding.mlir
    mlir/test/Dialect/Tosa/inlining.mlir
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    mlir/test/Dialect/Tosa/broadcast.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 47cda3c9f481ee..4975530a9588ca 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -231,7 +231,7 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
 //===----------------------------------------------------------------------===//
 
 class Tosa_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface, 
+    Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
     TosaResolvableShapeOperands])> {
 }
 
@@ -241,6 +241,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
               TosaElementwiseOperator,
+              SameOperandsAndResultRank,
               Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index d08d5fea66310e..3c1ca6aac90960 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   }
 }
 
+/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
+/// and all nested regions
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  int errs = 0;
+  for (auto &block : region) {
+    for (auto &op : block) {
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+        continue;
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+          errs++;
+        }
+      }
+      WhileOp whileOp = dyn_cast<WhileOp>(op);
+      IfOp ifOp = dyn_cast<IfOp>(op);
+      if (whileOp || ifOp) {
+        // recurse into whileOp's regions
+        for (auto &next : op.getRegions()) {
+          validateSameOperandsAndResultRankTrait(next);
+        }
+      }
+    }
+  }
+}
+
 /// Pass that performs shape propagation across TOSA operations. This includes
 /// migrating to within the regions of if/while operations.
 struct TosaInferShapes
@@ -313,6 +339,8 @@ struct TosaInferShapes
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
+
+    validateSameOperandsAndResultRankTrait(func.getBody());
   }
 };
 } // namespace

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index ea1b79cbd9507f..75b48f2b06d899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -45,3 +45,11 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
   %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
+
+// -----
+
+func.func @test_add_2d_
diff erent_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  // expected-error at +1 {{'tosa.add' op operands don't have matching ranks}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e2c95ba7a0c6b9..f860dca85c9e9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -103,20 +104,20 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>
 
-// CHECK-LABEL:   func.func @test_add_0d_broadcast(
+// CHECK-LABEL:   func.func @test_add_2d_broadcast(
 // CHECK-SAME:                                     %[[ARG0:.*]]: tensor<2x1xf32>,
-// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
-// CHECK:           %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
 // CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
-// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
 // CHECK:           ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:             %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
 // CHECK:             linalg.yield %[[ADD]] : f32
 // CHECK:           } -> tensor<2x1xf32>
 // CHECK:           return %[[RESULT]] : tensor<2x1xf32>
 // CHECK:         }
-func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> {
+  // tosa element-wise operators now require operands of equal ranks
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
   return %0 : tensor<2x1xf32>
 }
 
@@ -383,28 +384,6 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_
diff erent_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
-func.func @test_add_2d_
diff erent_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
-  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
-  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
-  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
-  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
-  // CHECK:   linalg.yield %[[VAL_4]] : f32
-  // CHECK: } -> tensor<2x3x4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
-  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
-  return %0 : tensor<2x3x4xf32>
-}
-
-// -----
-
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: @test_select_2d_one_dynamic

diff  --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
deleted file mode 100644
index 7613aa3b8dd03d..00000000000000
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ /dev/null
@@ -1,285 +0,0 @@
-// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
-
-// -----
-// CHECK-LABEL: broadcast0
-func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
-  //  CHECK-NOT: reshape
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
-  return %0 : tensor<1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast1
-func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast2
-func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast3
-func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
-  return %0 : tensor<2x1x1x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast4
-func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
-  return %0 : tensor<1x1x1x2xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast5
-func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
-  return %0 : tensor<1x1x2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast6
-func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast7
-func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast8
-func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast9
-func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast10
-func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast13
-func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast14
-func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast15
-func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast16
-func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast17
-func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast18
-func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
-  return %0 : tensor<14x15xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast19
-func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 17>}
-  // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]]
-  %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
-  return %0 : tensor<64x64x17xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast20
-func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4, 5>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
-  return %0 : tensor<3x3x4x5xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_arithmetic_right_shift
-func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_scalar
-func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_both_input
-func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
-  return %0 : tensor<1x16x16xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_one_input
-func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_predicate
-func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_abc
-func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_acb
-func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bac
-func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bca
-func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cab
-func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cba
-func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 869bce08a8c729..3ff3121348fcad 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
 }
 
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
-  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }

diff  --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index d57b5cbcf475c8..e892fdaa277505 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
 }
 func.func private @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) {
   %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+  %3 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+  %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
   return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
 }
 func.func private @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> {

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cc7fd009f01fa6..deaa8e24423374 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1018,3 +1018,19 @@ func.func @test_const_shape_value() -> !tosa.shape<5> {
   %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5>
   return %cst : !tosa.shape<5>
 }
+
+// -----
+
+func.func @test_sub_with_unequal_operand_ranks(%arg0: tensor<1x21x3xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sub' op operands don't have matching ranks}}
+  %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32>
+  return %0 : tensor<1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sub' op result type has 
diff erent rank than operands}}
+  %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
+  return %0 : tensor<1x13x21x3xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f4da66ef561b26..8ab7284019f965 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -11,15 +11,15 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
 // -----
 
 // CHECK-LABEL: @test_multiple
-func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
   // CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
   %1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
 
-  // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
 
@@ -104,33 +104,33 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
 // -----
 
 // CHECK-LABEL: @test_binary_scalar_f32
-func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
-  // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
+  // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
-  // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
-  // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   return
 }
@@ -172,48 +172,48 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
 // -----
 
 // CHECK-LABEL: @test_binary_i32
-func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
-  // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () {
+  // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %3 = tosa.bitwise_xor %arg0, %arg1: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %3 = tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+  %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
 
-  // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+  %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
 
-  // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+  %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
 
-  // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK:  tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK:  tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
   return
 }
@@ -221,15 +221,15 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
 // -----
 
 // CHECK-LABEL: @test_binary_i1
-func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
-  // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
-  %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi1>) -> () {
+  // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+  %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
 
-  // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
-  %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+  // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+  %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
 
-  // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
-  %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+  // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+  %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
 
   return
 }
@@ -237,9 +237,9 @@ func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
 // -----
 
 // CHECK-LABEL: @test_select_i32
-func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : tensor<4xi32>) -> () {
-  // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<4xi32>
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<*xi32>
+func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi32>, %arg2 : tensor<4xi32>) -> () {
+  // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xi32>
 
   return
 }


        


More information about the Mlir-commits mailing list