[Mlir-commits] [mlir] [TOSA] Don't run validation pass on non TOSA operations (PR #120205)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 17 01:58:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit ensures the validation pass is not run on operations from other dialects. As a result, operations from other dialects (for example, ones that use types TOSA does not support) no longer produce validation errors.

---
Full diff: https://github.com/llvm/llvm-project/pull/120205.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 893cedefc1ebde..62bbeead4d4a7b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -543,6 +543,9 @@ bool TosaValidation::isValidElementType(Type type) {
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
+    if (!op->getDialect() || op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      return;
+
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
       if (!isValidElementType(elementTy)) {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 79bb7fce5755ef..cca50b25d14d6b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -625,7 +625,6 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
 func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
   // expected-error at +1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
   %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
-  // expected-error at +1 {{'func.return' op is not profile-aligned: element type 'i64' is not legal}}
   return %0 : tensor<1x13x13xi64>
 }
 
@@ -879,4 +878,13 @@ func.func @test_mismatch_in_out_shape_logical_not(%arg0: tensor<1x21x3xi1>) -> t
   // expected-error at +1 {{'tosa.logical_not' op requires the same shape for all operands and results}}
   %0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<13x21x3xi1>
   return %0 : tensor<13x21x3xi1>
-}
\ No newline at end of file
+}
+
+// -----
+
+// Check validate pass doesn't run on non TOSA ops
+func.func @test_non_tosa_ops() {
+  %0 = arith.constant 6 : index
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/120205


More information about the Mlir-commits mailing list