[Mlir-commits] [mlir] [TOSA] Don't run validation pass on non TOSA operations (PR #120205)

Luke Hutton llvmlistbot at llvm.org
Tue Dec 17 01:58:15 PST 2024


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/120205

This commit makes the TOSA validation pass check only operations that belong to the TOSA dialect. As a result, operations from other dialects no longer produce validation errors, for example when they use types that TOSA does not support, such as an i64 tensor created by `tensor.empty`.

>From be33e631078bf10bb84f46f9cd7c28add658a7d0 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 13 Dec 2024 22:10:21 +0000
Subject: [PATCH] [TOSA] Don't run validation pass on non TOSA operations

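This commit makes the TOSA validation pass check only operations
that belong to the TOSA dialect. As a result, operations from other
dialects no longer produce validation errors, for example when they
use types that TOSA does not support (such as an i64 tensor created
by tensor.empty). Because func.return is no longer validated, its
expected error is removed from the invalid.mlir test.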

Change-Id: If1efde2036f2d3e13b8c8588fea6344922453c2b
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp |  3 +++
 mlir/test/Dialect/Tosa/invalid.mlir                 | 12 ++++++++++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 893cedefc1ebde..62bbeead4d4a7b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -543,6 +543,9 @@ bool TosaValidation::isValidElementType(Type type) {
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
+    if (!op->getDialect() || op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      return;
+
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
       if (!isValidElementType(elementTy)) {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 79bb7fce5755ef..cca50b25d14d6b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -625,7 +625,6 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
 func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
   // expected-error at +1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
   %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
-  // expected-error at +1 {{'func.return' op is not profile-aligned: element type 'i64' is not legal}}
   return %0 : tensor<1x13x13xi64>
 }
 
@@ -879,4 +878,13 @@ func.func @test_mismatch_in_out_shape_logical_not(%arg0: tensor<1x21x3xi1>) -> t
   // expected-error at +1 {{'tosa.logical_not' op requires the same shape for all operands and results}}
   %0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<13x21x3xi1>
   return %0 : tensor<13x21x3xi1>
-}
\ No newline at end of file
+}
+
+// -----
+
+// Check that the validation pass skips non-TOSA ops: tensor.empty below produces an i64 tensor but must not be flagged
+func.func @test_non_tosa_ops() {
+  %0 = arith.constant 6 : index
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  return
+}



More information about the Mlir-commits mailing list