[Mlir-commits] [mlir] [TOSA] Add SameOperandsAndResultRank to TOSA Ops (PR #104501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 15 13:38:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This patch adds the SameOperandsAndResultRank trait to TOSA operators that have the ResultsBroadcastableShape trait. The SameOperandsAndResultRank trait requires that all operands and results have matching ranks, unless the operand/result is unranked.

This also renders the TosaMakeBroadcastable pass unnecessary, but the pass is left in for now in case it is still used in some flows. The lit test broadcast.mlir is removed.

This also adds verification of the SameOperandsAndResultRank trait to the TosaInferShapes pass, to validate the inferred shapes.


---

Patch is 32.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104501.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+28) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+3-16) 
- (removed) mlir/test/Dialect/Tosa/broadcast.mlir (-285) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..0542319c96f889 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -219,6 +219,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
+              SameOperandsAndResultRank,
               Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..816c8aebc4a644 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -285,6 +285,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   }
 }
 
+/// Validate the SameOperandsAndResultRank trait on each TOSA-dialect op in `region`,
+/// recursing only into the regions of WhileOp/IfOp ops (other nested regions are not visited).
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  int errs = 0; // NOTE(review): incremented below but never read; the void return drops it, so callers cannot fail the pass.
+  for (auto &block : region) {
+    for (auto &op : block) {
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+        continue; // Skip ops with no dialect or from a non-TOSA dialect.
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+          errs++; // Presumably the trait verifier has already emitted an op error here — confirm.
+        }
+      }
+      WhileOp whileOp = dyn_cast<WhileOp>(op);
+      IfOp ifOp = dyn_cast<IfOp>(op);
+      if (whileOp || ifOp) {
+        // Recurse into every region of a WhileOp or IfOp; errors found there are not counted.
+        for (auto &next : op.getRegions()) {
+          validateSameOperandsAndResultRankTrait(next);
+        }
+      }
+    }
+  }
+}
+
 /// Pass that performs shape propagation across TOSA operations. This includes
 /// migrating to within the regions of if/while operations.
 struct TosaInferShapes
@@ -295,6 +321,8 @@ struct TosaInferShapes
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
+
+    validateSameOperandsAndResultRankTrait(func.getBody());
   }
 };
 } // namespace
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..23b35f11aa02f7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
-  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
-  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
-  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
-  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
-  // CHECK:   linalg.yield %[[VAL_4]] : f32
-  // CHECK: } -> tensor<2x3x4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
-  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+  // expected-error at +1 {{'tosa.add' op operands don't have matching ranks}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
deleted file mode 100644
index 7613aa3b8dd03d..00000000000000
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ /dev/null
@@ -1,285 +0,0 @@
-// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
-
-// -----
-// CHECK-LABEL: broadcast0
-func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
-  //  CHECK-NOT: reshape
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
-  return %0 : tensor<1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast1
-func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast2
-func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast3
-func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
-  return %0 : tensor<2x1x1x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast4
-func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
-  return %0 : tensor<1x1x1x2xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast5
-func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
-  return %0 : tensor<1x1x2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast6
-func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast7
-func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast8
-func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast9
-func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast10
-func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast13
-func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast14
-func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast15
-func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast16
-func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast17
-func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast18
-func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
-  return %0 : tensor<14x15xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast19
-func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 17>}
-  // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]]
-  %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
-  return %0 : tensor<64x64x17xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast20
-func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4, 5>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
-  return %0 : tensor<3x3x4x5xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_arithmetic_right_shift
-func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_scalar
-func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_both_input
-func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
-  return %0 : tensor<1x16x16xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_one_input
-func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_predicate
-func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_abc
-func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_acb
-func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bac
-func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bca
-func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cab
-func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cba
-func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 869bce08a8c729..3ff3121348fcad 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
 }
 
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
-  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index d57b5cbcf475c8..e892fdaa277505 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
 }
 func.func pri...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/104501


More information about the Mlir-commits mailing list