[Mlir-commits] [mlir] [TOSA] Add SameOperandsAndResultRank to TOSA Ops (PR #104501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 15 13:38:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This patch adds SameOperandsAndResultRank trait to TOSA operators with ResultsBroadcastableShape trait. SameOperandsAndResultRank trait requiring that all operands and results have matching ranks unless the operand/result is unranked.
This also renders the TosaMakeBroadcastable pass unnecessary - but this pass is left in for now just in case it is still used in some flows. The lit test, broadcast.mlir, is removed.
This also adds verify of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.
---
Patch is 32.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104501.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+1)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+28)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+3-16)
- (removed) mlir/test/Dialect/Tosa/broadcast.mlir (-285)
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..0542319c96f889 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -219,6 +219,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..816c8aebc4a644 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -285,6 +285,32 @@ void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
}
}
+/// recursively validate tosa ops with SameOperandsAndResultRank trait in region
+/// and all nested regions
+void validateSameOperandsAndResultRankTrait(Region ®ion) {
+ int errs = 0;
+ for (auto &block : region) {
+ for (auto &op : block) {
+ if (!op.getDialect() ||
+ op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ continue;
+ if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+ if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+ errs++;
+ }
+ }
+ WhileOp whileOp = dyn_cast<WhileOp>(op);
+ IfOp ifOp = dyn_cast<IfOp>(op);
+ if (whileOp || ifOp) {
+ // recurse into whileOp's regions
+ for (auto &next : op.getRegions()) {
+ validateSameOperandsAndResultRankTrait(next);
+ }
+ }
+ }
+ }
+}
+
/// Pass that performs shape propagation across TOSA operations. This includes
/// migrating to within the regions of if/while operations.
struct TosaInferShapes
@@ -295,6 +321,8 @@ struct TosaInferShapes
TypeModificationState state;
propagateShapesInRegion(func.getBody(), state);
state.commit();
+
+ validateSameOperandsAndResultRankTrait(func.getBody());
}
};
} // namespace
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..23b35f11aa02f7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: } -> tensor<f32>
%0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+
// CHECK: return [[RESULT]] : tensor<f32>
return %0 : tensor<f32>
}
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// -----
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
- // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
- // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
- // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
- // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
- // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
- // CHECK: linalg.yield %[[VAL_4]] : f32
- // CHECK: } -> tensor<2x3x4xf32>
- %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
- // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+ // expected-error at +1 {{'tosa.add' op operands don't have matching ranks}}
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
deleted file mode 100644
index 7613aa3b8dd03d..00000000000000
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ /dev/null
@@ -1,285 +0,0 @@
-// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
-
-// -----
-// CHECK-LABEL: broadcast0
-func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
- // CHECK-NOT: reshape
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
- return %0 : tensor<1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast1
-func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
- return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast2
-func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
- return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast3
-func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
- return %0 : tensor<2x1x1x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast4
-func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
- return %0 : tensor<1x1x1x2xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast5
-func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
- return %0 : tensor<1x1x2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast6
-func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast7
-func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
- return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast8
-func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast9
-func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast10
-func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast13
-func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast14
-func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
- return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast15
-func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast16
-func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast17
-func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast18
-func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
- return %0 : tensor<14x15xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast19
-func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 17>}
- // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]]
- %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
- return %0 : tensor<64x64x17xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast20
-func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4, 5>}
- // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
- %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
- return %0 : tensor<3x3x4x5xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_arithmetic_right_shift
-func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_scalar
-func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
- %0 = tosa.add %arg0, %arg1 : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_both_input
-func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
- return %0 : tensor<1x16x16xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_one_input
-func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
- return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_predicate
-func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_abc
-func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_acb
-func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bac
-func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bca
-func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cab
-func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cba
-func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
- // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
- // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
- // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
- %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
- return %0 : tensor<1x32x32x8xf32>
-}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 869bce08a8c729..3ff3121348fcad 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
}
// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
// CHECK: tosa.equal
// CHECK-NEXT: return
- %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+ %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index d57b5cbcf475c8..e892fdaa277505 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
}
func.func pri...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/104501
More information about the Mlir-commits
mailing list