[Mlir-commits] [mlir] [tosa] Add verifier checks for Scatter (PR #142661)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 3 12:36:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This adds verifier checks for the scatter op
to ensure that the shapes of its inputs and output
are consistent with the TOSA specification.


---
Full diff: https://github.com/llvm/llvm-project/pull/142661.diff


9 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+67) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+11-11) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+72) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
           .failed()) {
     return failure();
   }
+
+  const ShapeAdaptor valuesInShape(getValuesIn().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor inputShape(getInput().getType());
+  const ShapeAdaptor outputShape(getValuesOut().getType());
+
+  int64_t N = ShapedType::kDynamic;
+  int64_t K = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  if (valuesInShape.hasRank()) {
+    N = valuesInShape.getDimSize(0);
+    K = valuesInShape.getDimSize(1);
+    C = valuesInShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    W = indicesShape.getDimSize(1);
+    if (N == ShapedType::kDynamic)
+      N = indicesN;
+    else if (indicesN != ShapedType::kDynamic && N != indicesN)
+      return emitOpError() << "requires indices dimension 0 to have size " << N
+                           << ", got " << indicesN;
+  }
+  if (inputShape.hasRank()) {
+    const int64_t inputN = inputShape.getDimSize(0);
+    const int64_t inputW = inputShape.getDimSize(1);
+    const int64_t inputC = inputShape.getDimSize(2);
+    if (N == ShapedType::kDynamic)
+      N = inputN;
+    else if (inputN != ShapedType::kDynamic && N != inputN)
+      return emitOpError() << "requires input dimension 0 to have size " << N
+                           << ", got " << inputN;
+    if (W == ShapedType::kDynamic)
+      W = inputW;
+    else if (inputW != ShapedType::kDynamic && W != inputW)
+      return emitOpError() << "requires input dimension 1 to have size " << W
+                           << ", got " << inputW;
+
+    if (C == ShapedType::kDynamic)
+      C = inputC;
+    else if (inputC != ShapedType::kDynamic && C != inputC)
+      return emitOpError() << "requires input dimension 2 to have size " << C
+                           << ", got " << inputC;
+  }
+  if (outputShape.hasRank()) {
+    const int64_t outputN = outputShape.getDimSize(0);
+    const int64_t outputK = outputShape.getDimSize(1);
+    const int64_t outputC = outputShape.getDimSize(2);
+    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+        N != outputN)
+      return emitOpError() << "requires values_out dimension 0 to have size "
+                           << N << ", got " << outputN;
+    if (K == ShapedType::kDynamic)
+      K = outputK;
+    else if (outputK != ShapedType::kDynamic && K != outputK)
+      return emitOpError() << "requires values_out dimension 1 to have size "
+                           << K << ", got " << outputK;
+    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+        C != outputC)
+      return emitOpError() << "requires values_out dimension 2 to have size "
+                           << C << ", got " << outputC;
+  }
+  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+    return emitOpError() << "requires dimensions K >= W, got K=" << K
+                         << " and W=" << W;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
-  return %0 : tensor<13x21x3xbf16>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+  return %0 : tensor<13x26x3xbf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
 
 // -----
 
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
-  return %0 : tensor<13x210000000x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+  return %0 : tensor<13x260000000x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
 }
 
 // -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
 
 // -----
 // CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
 
 // -----
 // CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
-  return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+  return %0 : tensor<13x52x3xf8E5M2>
 }
 
 // -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
 
 // -----
 // CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
-  return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+  return %0 : tensor<13x29x3xf8E4M3FN>
 }
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
-  return %0 : tensor<13x21x3xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+  return %0 : tensor<13x27x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
 // -----
 
 // CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
-  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
   tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+  %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/142661


More information about the Mlir-commits mailing list