[Mlir-commits] [mlir] [tosa] Add verifier checks for Scatter (PR #142661)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 12:36:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This adds verifier checks for the scatter op
to make sure the shapes of its inputs and output
are consistent with the TOSA specification.
---
Full diff: https://github.com/llvm/llvm-project/pull/142661.diff
9 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+67)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+11-11)
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+72)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
.failed()) {
return failure();
}
+
+ const ShapeAdaptor valuesInShape(getValuesIn().getType());
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ const ShapeAdaptor inputShape(getInput().getType());
+ const ShapeAdaptor outputShape(getValuesOut().getType());
+
+ int64_t N = ShapedType::kDynamic;
+ int64_t K = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ if (valuesInShape.hasRank()) {
+ N = valuesInShape.getDimSize(0);
+ K = valuesInShape.getDimSize(1);
+ C = valuesInShape.getDimSize(2);
+ }
+ if (indicesShape.hasRank()) {
+ const int64_t indicesN = indicesShape.getDimSize(0);
+ W = indicesShape.getDimSize(1);
+ if (N == ShapedType::kDynamic)
+ N = indicesN;
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << N
+ << ", got " << indicesN;
+ }
+ if (inputShape.hasRank()) {
+ const int64_t inputN = inputShape.getDimSize(0);
+ const int64_t inputW = inputShape.getDimSize(1);
+ const int64_t inputC = inputShape.getDimSize(2);
+ if (N == ShapedType::kDynamic)
+ N = inputN;
+ else if (inputN != ShapedType::kDynamic && N != inputN)
+ return emitOpError() << "requires input dimension 0 to have size " << N
+ << ", got " << inputN;
+ if (W == ShapedType::kDynamic)
+ W = inputW;
+ else if (inputW != ShapedType::kDynamic && W != inputW)
+ return emitOpError() << "requires input dimension 1 to have size " << W
+ << ", got " << inputW;
+
+ if (C == ShapedType::kDynamic)
+ C = inputC;
+ else if (inputC != ShapedType::kDynamic && C != inputC)
+ return emitOpError() << "requires input dimension 2 to have size " << C
+ << ", got " << inputC;
+ }
+ if (outputShape.hasRank()) {
+ const int64_t outputN = outputShape.getDimSize(0);
+ const int64_t outputK = outputShape.getDimSize(1);
+ const int64_t outputC = outputShape.getDimSize(2);
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ N != outputN)
+ return emitOpError() << "requires values_out dimension 0 to have size "
+ << N << ", got " << outputN;
+ if (K == ShapedType::kDynamic)
+ K = outputK;
+ else if (outputK != ShapedType::kDynamic && K != outputK)
+ return emitOpError() << "requires values_out dimension 1 to have size "
+ << K << ", got " << outputK;
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ C != outputC)
+ return emitOpError() << "requires values_out dimension 2 to have size "
+ << C << ", got " << outputC;
+ }
+ if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+ return emitOpError() << "requires dimensions K >= W, got K=" << K
+ << " and W=" << W;
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
// expected-error at +1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+ return %0 : tensor<13x26x3xbf16>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
// -----
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
// expected-error at +1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
- return %0 : tensor<13x210000000x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+ return %0 : tensor<13x260000000x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
}
// -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// -----
// CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
// -----
// CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
- return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+ return %0 : tensor<13x52x3xf8E5M2>
}
// -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
// -----
// CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
- return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+ return %0 : tensor<13x29x3xf8E4M3FN>
}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// expected-error at +1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
// expected-error at +1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
- return %0 : tensor<13x21x3xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+ return %0 : tensor<13x27x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
// -----
// CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
- // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+ // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+ %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+ // expected-error at +1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/142661
More information about the Mlir-commits mailing list