[Mlir-commits] [mlir] [TOSA] Add SameOperandsAndResultRank to TOSA Ops (PR #104501)
Tai Ly
llvmlistbot at llvm.org
Thu Aug 15 13:38:21 PDT 2024
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/104501
This patch adds the SameOperandsAndResultRank trait to TOSA operators that have the ResultsBroadcastableShape trait. The SameOperandsAndResultRank trait requires that all operands and results have matching ranks unless the operand/result is unranked.
This also renders the TosaMakeBroadcastable pass unnecessary - but this pass is left in for now just in case it is still used in some flows. The lit test, broadcast.mlir, is removed.
This also adds verification of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.
>From 4e68f23edcf0b4071827a7945c22d42707fc76c6 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 3 Aug 2023 21:57:52 +0000
Subject: [PATCH] [TOSA] Add SameOperandsAndResultRank to TOSA Ops
This patch adds the SameOperandsAndResultRank trait to TOSA operators
that have the ResultsBroadcastableShape trait. The SameOperandsAndResultRank
trait requires that all operands and results have matching ranks unless
the operand/result is unranked.
This also renders the TosaMakeBroadcastable pass unnecessary - but
this pass is left in for now just in case it is still used in some
flows. The lit test, broadcast.mlir, is removed.
This also adds verification of the SameOperandsAndResultRank trait in the
TosaInferShapes pass to validate inferred shapes.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I27bf16b31f15aa92d42ad5376b8791cf74e4f6ac
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 1 +
.../Tosa/Transforms/TosaInferShapes.cpp | 28 ++
.../TosaToLinalg/tosa-to-linalg.mlir | 19 +-
mlir/test/Dialect/Tosa/broadcast.mlir | 285 ------------------
mlir/test/Dialect/Tosa/constant_folding.mlir | 4 +-
mlir/test/Dialect/Tosa/inlining.mlir | 3 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 122 ++++----
7 files changed, 97 insertions(+), 365 deletions(-)
delete mode 100644 mlir/test/Dialect/Tosa/broadcast.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..0542319c96f889 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -219,6 +219,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..816c8aebc4a644 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -285,6 +285,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
}
}
+/// Recursively checks every TOSA op in `region` that carries the
+/// SameOperandsAndResultRank trait, descending into the regions of TOSA
+/// while/if ops. Non-TOSA ops (and any regions they own) are skipped.
+/// Diagnostics are emitted by OpTrait::impl::verifySameOperandsAndResultRank;
+/// this helper keeps scanning after a violation so that every offending op is
+/// reported, and returns failure if any op in `region` or in a nested region
+/// failed the check.
+static LogicalResult verifySameOperandsAndResultRankInRegion(Region &region) {
+  bool anyFailed = false;
+  for (Block &block : region) {
+    for (Operation &op : block) {
+      Dialect *dialect = op.getDialect();
+      if (!dialect ||
+          dialect->getNamespace() != TosaDialect::getDialectNamespace())
+        continue;
+
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>() &&
+          failed(OpTrait::impl::verifySameOperandsAndResultRank(&op)))
+        anyFailed = true;
+
+      // Only TOSA control-flow ops are descended into; their bodies may
+      // contain further elementwise ops whose ranks were just inferred.
+      if (isa<WhileOp, IfOp>(&op)) {
+        for (Region &nested : op.getRegions())
+          if (failed(verifySameOperandsAndResultRankInRegion(nested)))
+            anyFailed = true;
+      }
+    }
+  }
+  return failure(anyFailed);
+}
+
+/// Recursively validates TOSA ops with the SameOperandsAndResultRank trait in
+/// `region` and all nested TOSA while/if regions. Violations are reported as
+/// diagnostics on the offending ops. The aggregated result is intentionally
+/// discarded here to keep this entry point's signature unchanged.
+/// NOTE(review): callers that must fail on a violation (e.g. to call
+/// signalPassFailure) should use verifySameOperandsAndResultRankInRegion.
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  (void)verifySameOperandsAndResultRankInRegion(region);
+}
+
/// Pass that performs shape propagation across TOSA operations. This includes
/// migrating to within the regions of if/while operations.
struct TosaInferShapes
@@ -295,6 +321,8 @@ struct TosaInferShapes
TypeModificationState state;
propagateShapesInRegion(func.getBody(), state);
state.commit();
+
+ validateSameOperandsAndResultRankTrait(func.getBody());
}
};
} // namespace
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..23b35f11aa02f7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: } -> tensor<f32>
%0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+
// CHECK: return [[RESULT]] : tensor<f32>
return %0 : tensor<f32>
}
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// -----
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
- // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
- // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
- // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
- // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
- // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
- // CHECK: linalg.yield %[[VAL_4]] : f32
- // CHECK: } -> tensor<2x3x4xf32>
- %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
- // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+ // expected-error at +1 {{'tosa.add' op operands don't have matching ranks}}
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
deleted file mode 100644
index 7613aa3b8dd03d..00000000000000
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ /dev/null
@@ -1,285 +0,0 @@
-// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
-
-// -----
-// CHECK-LABEL: broadcast0
-func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
- // CHECK-NOT: reshape
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
- return %0 : tensor<1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast1
-func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
- return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast2
-func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
- return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast3
-func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
- return %0 : tensor<2x1x1x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast4
-func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
- return %0 : tensor<1x1x1x2xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast5
-func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
- return %0 : tensor<1x1x2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast6
-func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast7
-func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
- return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast8
-func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast9
-func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast10
-func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast13
-func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast14
-func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
- return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast15
-func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast16
-func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast17
-func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast18
-func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
- return %0 : tensor<14x15xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast19
-func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 17>}
- // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]]
- %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
- return %0 : tensor<64x64x17xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast20
-func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4, 5>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
- return %0 : tensor<3x3x4x5xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_arithmetic_right_shift
-func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_scalar
-func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_both_input
-func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
- return %0 : tensor<1x16x16xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_one_input
-func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_predicate
-func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_abc
-func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_acb
-func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bac
-func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bca
-func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cab
-func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cba
-func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 869bce08a8c729..3ff3121348fcad 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
}
// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
// CHECK: tosa.equal
// CHECK-NEXT: return
- %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+ %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index d57b5cbcf475c8..e892fdaa277505 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
}
func.func private @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) {
%1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+ %3 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+ %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
}
func.func private @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 3224f88968f3d2..3dfe2d6799021e 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -11,15 +11,15 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
// -----
// CHECK-LABEL: @test_multiple
-func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
// CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
%1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
- // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@@ -104,33 +104,33 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
// -----
// CHECK-LABEL: @test_binary_scalar_f32
-func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
- // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
+ // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
- %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+ // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
- %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+ // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+ %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
- // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
- %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+ // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+ %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
- // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
- %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+ // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+ %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
return
}
@@ -172,48 +172,48 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
// -----
// CHECK-LABEL: @test_binary_i32
-func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
- // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () {
+ // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %3 = tosa.bitwise_xor %arg0, %arg1: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %3 = tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
- %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+ // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+ %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
- %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+ // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+ %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
- %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+ // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1>
+ %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
- %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}
@@ -221,15 +221,15 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
// -----
// CHECK-LABEL: @test_binary_i1
-func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
- // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
- %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi1>) -> () {
+ // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+ %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
- // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
- %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+ // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+ %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
- // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
- %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+ // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1>
+ %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1>
return
}
@@ -237,9 +237,9 @@ func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
// -----
// CHECK-LABEL: @test_select_i32
-func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : tensor<4xi32>) -> () {
- // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<4xi32>
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<*xi32>
+func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi32>, %arg2 : tensor<4xi32>) -> () {
+ // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xi32>
return
}
More information about the Mlir-commits
mailing list