[Mlir-commits] [mlir] bef966e - tosa-make-broadcastable pass now supports numpy-style broadcasting only.

Rob Suderman llvmlistbot at llvm.org
Wed Nov 10 11:49:13 PST 2021


Author: Kevin Cheng
Date: 2021-11-10T11:48:35-08:00
New Revision: bef966eb376e0756af3a8d434acfc3b250f59257

URL: https://github.com/llvm/llvm-project/commit/bef966eb376e0756af3a8d434acfc3b250f59257
DIFF: https://github.com/llvm/llvm-project/commit/bef966eb376e0756af3a8d434acfc3b250f59257.diff

LOG: tosa-make-broadcastable pass now supports numpy-style broadcasting only.

- fix a bug in the [c, 1] + [a, b, c, d] broadcast (alignment illustrated below)
- add a test for [3, 3, 4, 1] + [4, 5]
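
For concreteness, numpy-style broadcasting aligns shapes from the trailing dimension and pads the shorter shape with leading 1s. For the added test that means:

      [3, 3, 4, 1]
    +       [4, 5]   ->  rhs reshaped to [1, 1, 4, 5]  ->  result [3, 3, 4, 5]

and in the fixed case, [c, 1] against [a, b, c, d] pads to [1, 1, c, 1] instead of being matched against the leading dimensions.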

Signed-off-by: Kevin Cheng <kevin.cheng at arm.com>
Change-Id: Iaed2f04df8775f655c82c740271395274163d147

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D113596

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
    mlir/test/Dialect/Tosa/broadcast.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index c8fb9c505eeef..36b2287394338 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -25,7 +25,7 @@ using namespace mlir::tosa;
 /// There are two potential ways of implementing broadcast:
 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
-/// TBD: picking option (a) now.
+/// This pass now implements option (b), numpy-style broadcasting.
 
 /// In this pass, we insert RESHAPE operators to increase the rank of the
 /// lower rank operand as a first step in the broadcasting process. The TOSA
@@ -33,75 +33,39 @@ using namespace mlir::tosa;
 /// are equal.
 
 // Examples:
-// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
-// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into
-// [1, b, 1].
-// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
-// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
-// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
-// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
-// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
+// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
+// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
+// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
 
-static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
-                                 ArrayRef<int64_t> lowerRankShape,
-                                 SmallVectorImpl<int64_t> &reshapeOutputShape) {
+static LogicalResult
+computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
+                     ArrayRef<int64_t> lowerRankShape,
+                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
   // Initialize new shapes with [1] * higherRank.
   int64_t higherRank = higherRankShape.size();
   int64_t lowerRank = lowerRankShape.size();
 
   reshapeOutputShape.assign(higherRank, 1);
 
-  int64_t higherLeftIndex = 0;
-  int64_t higherRightIndex = higherRank;
-  int64_t lowerLeftIndex = 0;
-  int64_t lowerRightIndex = lowerRank;
-  int64_t higherRankDim, lowerRankDim;
-
-  if (lowerRightIndex != 0 && higherRightIndex != 0) {
-    // Matches lower rank shape from right dimension first, until not
-    // matching high rank shape or reaching dimension 0.
-    while (true) {
-      higherRankDim = higherRankShape[higherRightIndex - 1];
-      lowerRankDim = lowerRankShape[lowerRightIndex - 1];
-      if (higherRankDim != lowerRankDim)
-        break;
-
-      reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
-
-      if (higherRightIndex > 0)
-        higherRightIndex--;
-
-      if (lowerRightIndex > 0)
-        lowerRightIndex--;
-
-      if (higherRightIndex == 0 || lowerRightIndex == 0)
-        break;
-    }
-    if (lowerRightIndex != 0 && higherRightIndex != 0) {
-      // Matches lower rank shape from left dimension, until not matching
-      // high rank shape or reaching right index.
-      while (true) {
-        higherRankDim = higherRankShape[higherLeftIndex];
-        lowerRankDim = lowerRankShape[lowerLeftIndex];
-        if (higherRankDim != lowerRankDim)
-          break;
-
-        reshapeOutputShape[higherLeftIndex] = higherRankDim;
-
-        if (higherLeftIndex < higherRightIndex)
-          higherLeftIndex++;
-
-        if (lowerLeftIndex < lowerRightIndex)
-          lowerLeftIndex++;
-
-        if (higherLeftIndex == higherRightIndex ||
-            lowerLeftIndex == lowerRightIndex)
-          break;
-      }
-    }
+  int64_t higherRankDim;
+  int64_t lowerRankDim;
+
+  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
+       i--, j--) {
+    higherRankDim = higherRankShape[i];
+    lowerRankDim = lowerRankShape[j];
+
+    if (lowerRankDim == 1 && higherRankDim > 1)
+      reshapeOutputShape[i] = 1;
+    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
+             (lowerRankDim == higherRankDim))
+      reshapeOutputShape[i] = lowerRankDim;
+    else if (higherRankDim != lowerRankDim)
+      return failure();
   }
+  return success();
 }
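
For readers following the new loop, here is a minimal standalone sketch of the same right-aligned, numpy-style matching using plain standard-library types instead of the MLIR ones. The function and driver names are illustrative only and are not part of the patch:

    // numpy_broadcast_sketch.cpp -- illustrative only, not part of the patch.
    // Mirrors the right-aligned matching in computeReshapeOutput above.
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    static std::optional<std::vector<int64_t>>
    computeReshape(const std::vector<int64_t> &higher,
                   const std::vector<int64_t> &lower) {
      // Pad the lower-rank shape with leading 1s up to the higher rank.
      std::vector<int64_t> out(higher.size(), 1);
      for (int64_t i = static_cast<int64_t>(higher.size()) - 1,
                   j = static_cast<int64_t>(lower.size()) - 1;
           i >= 0 && j >= 0; i--, j--) {
        if (lower[j] == higher[i] || higher[i] == 1)
          out[i] = lower[j];   // dimensions agree, or the higher side is 1
        else if (lower[j] != 1)
          return std::nullopt; // incompatible, e.g. [4, 5] vs [..., 3, 5]
        // lower[j] == 1 and higher[i] > 1: out[i] stays 1 (broadcast dim).
      }
      return out;
    }

    int main() {
      // The fixed case from the log: [c, 1] = [4, 1] vs [a, b, c, d] = [2, 3, 4, 5].
      if (auto r = computeReshape({2, 3, 4, 5}, {4, 1})) {
        for (int64_t d : *r)
          std::cout << d << ' '; // prints: 1 1 4 1
        std::cout << '\n';
      }
    }

The removed two-phase scan appears to have matched neither from the right (1 vs d) nor from the left (c vs a) for this case, leaving an all-ones shape that cannot hold the operand's elements; the single right-to-left walk removes that failure mode.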
 
 /// Common code to create the reshape op where necessary to make the rank of the
@@ -143,8 +107,9 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
 
   SmallVector<int64_t, 4> reshapeOutputShape;
 
-  computeReshapeOutput(outputType.getShape(), lowerRankShape,
-                       reshapeOutputShape);
+  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
+          .failed())
+    return failure();
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(

diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 16fb75450675b..bb96e3eda7caf 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -11,7 +11,8 @@ func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf3
 // -----
 // CHECK-LABEL: broadcast1
 func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
   return %0 : tensor<2x1xf32>
 }
@@ -19,7 +20,8 @@ func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x
 // -----
 // CHECK-LABEL: broadcast2
 func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
   return %0 : tensor<2x1xf32>
 }
@@ -27,7 +29,8 @@ func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x
 // -----
 // CHECK-LABEL: broadcast3
 func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
   return %0 : tensor<2x1x1x1xf32>
 }
@@ -35,7 +38,8 @@ func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tenso
 // -----
 // CHECK-LABEL: broadcast4
 func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
   return %0 : tensor<1x1x1x2xf32>
 }
@@ -43,7 +47,8 @@ func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tenso
 // -----
 // CHECK-LABEL: broadcast5
 func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
   return %0 : tensor<1x1x2x1xf32>
 }
@@ -51,7 +56,8 @@ func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tenso
 // -----
 // CHECK-LABEL: broadcast6
 func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -59,7 +65,8 @@ func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> t
 // -----
 // CHECK-LABEL: broadcast7
 func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
   return %0 : tensor<17x16x1x14xf32>
 }
@@ -67,7 +74,8 @@ func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) ->
 // -----
 // CHECK-LABEL: broadcast8
 func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -75,7 +83,8 @@ func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) ->
 // -----
 // CHECK-LABEL: broadcast9
 func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 15, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -83,7 +92,8 @@ func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -
 // -----
 // CHECK-LABEL: broadcast10
 func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 15, 14]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -91,7 +101,8 @@ func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>)
 // -----
 // CHECK-LABEL: broadcast13
 func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -99,7 +110,8 @@ func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) ->
 // -----
 // CHECK-LABEL: broadcast14
 func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
   return %0 : tensor<17x16x1x14xf32>
 }
@@ -107,7 +119,8 @@ func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) ->
 // -----
 // CHECK-LABEL: broadcast15
 func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -115,7 +128,8 @@ func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -
 // -----
 // CHECK-LABEL: broadcast16
 func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -123,7 +137,8 @@ func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>)
 // -----
 // CHECK-LABEL: broadcast17
 func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
   return %0 : tensor<17x16x15x14xf32>
 }
@@ -131,24 +146,34 @@ func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>)
 // -----
 // CHECK-LABEL: broadcast18
 func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
-  //  CHECK: add
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
   return %0 : tensor<14x15xf32>
 }
 
 // -----
 // CHECK-LABEL: broadcast19
-func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
-  //  CHECK: reshape
-  //  CHECK: sub
+func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 17]}
+  // CHECK: %[[VAR1:.*]] = "tosa.sub"(%arg0, %[[VAR0]])
   %0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
   return %0 : tensor<64x64x17xf32>
 }
 
+// -----
+// CHECK-LABEL: broadcast20
+func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 4, 5]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
+  return %0 : tensor<3x3x4x5xf32>
+}
+
 // -----
 // CHECK-LABEL: broadcast_mul
 func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]}
+  // CHECK: %[[VAR1:.*]] = "tosa.mul"(%[[VAR0]], %arg1)
   %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
@@ -156,7 +181,8 @@ func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32
 // -----
 // CHECK-LABEL: broadcast_arithmetic_right_shift
 func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  //  CHECK: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]}
+  // CHECK: %[[VAR1:.*]] = "tosa.arithmetic_right_shift"(%[[VAR0]], %arg1)
   %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
@@ -164,7 +190,8 @@ func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: ten
 // -----
 // CHECK-LABEL: broadcast_scalar
 func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  //  CHECK-NEXT: reshape
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1)
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
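
To try the updated test locally, the usual lit/FileCheck flow applies; assuming the file's existing RUN line uses the pass's registered flag, something like:

    mlir-opt --tosa-make-broadcastable mlir/test/Dialect/Tosa/broadcast.mlir | FileCheck mlir/test/Dialect/Tosa/broadcast.mlir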
