[Mlir-commits] [mlir] b4e2592 - [mlir][tosa] Add verifier checks for Gather (#137204)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 25 03:47:09 PDT 2025
Author: Tai Ly
Date: 2025-04-25T11:47:06+01:00
New Revision: b4e259291326176602c55db2dbf8697ff3bca6e9
URL: https://github.com/llvm/llvm-project/commit/b4e259291326176602c55db2dbf8697ff3bca6e9
DIFF: https://github.com/llvm/llvm-project/commit/b4e259291326176602c55db2dbf8697ff3bca6e9.diff
LOG: [mlir][tosa] Add verifier checks for Gather (#137204)
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.
---------
Signed-off-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Luke Hutton <luke.hutton at arm.com>
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/verifier.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b5504ca84fa42..183893c9fdb46 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
}
LogicalResult tosa::GatherOp::verify() {
- return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
- /* outType = */ getOutput().getType());
+ if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+
+ const ShapeAdaptor valuesShape(getValues().getType());
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ const ShapeAdaptor outputShape(getOutput().getType());
+
+ int64_t N = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+
+ if (valuesShape.hasRank()) {
+ N = valuesShape.getDimSize(0);
+ C = valuesShape.getDimSize(2);
+ }
+ if (indicesShape.hasRank()) {
+ const int64_t indicesN = indicesShape.getDimSize(0);
+ W = indicesShape.getDimSize(1);
+ if (N == ShapedType::kDynamic)
+ N = indicesN;
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << N
+ << ", got " << indicesN;
+ }
+ if (outputShape.hasRank()) {
+ const int64_t outputN = outputShape.getDimSize(0);
+ const int64_t outputW = outputShape.getDimSize(1);
+ const int64_t outputC = outputShape.getDimSize(2);
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ N != outputN)
+ return emitOpError() << "requires output dimension 0 to have size " << N
+ << ", got " << outputN;
+
+ if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
+ W != outputW)
+ return emitOpError() << "requires output dimension 1 to have size " << W
+ << ", got " << outputW;
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ C != outputC)
+ return emitOpError() << "requires output dimension 2 to have size " << C
+ << ", got " << outputC;
+ }
+ return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e88fc11d2be88..b23dcd0c9cd3d 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -370,3 +370,36 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
return %0 : tensor<i16>
}
+
+// -----
+
+// CHECK-LABEL: @test_gather_invalid_indices_N
+func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_N
+func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
+ return %0 : tensor<12x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_W
+func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_C
+func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
+ return %0 : tensor<13x26x8xf32>
+}
More information about the Mlir-commits
mailing list