[Mlir-commits] [mlir] [mlir][tosa] Add error if checks Variable Operators (PR #137291)

TatWai Chong llvmlistbot at llvm.org
Fri Apr 25 00:25:24 PDT 2025


https://github.com/tatwaichong created https://github.com/llvm/llvm-project/pull/137291

Add error-if verifier checks for the VARIABLE, VARIABLE_WRITE & VARIABLE_READ operators.

>From e3a3d4e21458281c628f2dc6510179fe79217457 Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Thu, 24 Apr 2025 23:38:38 -0700
Subject: [PATCH] [mlir][tosa] Add error if checks Variable Operators

For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
---
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       |  4 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 84 +++++++++++++++++
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 ---
 mlir/test/Dialect/Tosa/invalid.mlir           |  8 +-
 mlir/test/Dialect/Tosa/variables.mlir         |  4 +-
 mlir/test/Dialect/Tosa/verifier.mlir          | 90 +++++++++++++++++++
 6 files changed, 184 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..6e5f6317816f2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
   let assemblyFormat = [{
     $name attr-dict `,` $input1 `:` type($input1)
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
   let assemblyFormat = [{
     $name attr-dict `:` type($output1)
   }];
+
+  let hasVerifier = 1;
 }
 
 #endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..b1312afbbf6d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -572,6 +572,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
   return success();
 }
 
+// Verify that the two types have the same element type and compatible shapes.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+                                                StringRef name1, Type type2,
+                                                StringRef name2) {
+  auto shapeType1 = dyn_cast<ShapedType>(type1);
+  auto shapeType2 = dyn_cast<ShapedType>(type2);
+  if (!shapeType1 || !shapeType2)
+    return failure();
+
+  auto elemType1 = shapeType1.getElementType();
+  auto elemType2 = shapeType2.getElementType();
+  if (elemType1 != elemType2)
+    return op->emitOpError()
+           << "require same element type for " << name1 << " (" << elemType1
+           << ") and " << name2 << " (" << elemType2 << ")";
+
+  if (failed(verifyCompatibleShape(type1, type2)))
+    return op->emitOpError()
+           << "require same shapes for " << name1 << " (" << type1 << ") and "
+           << name2 << " (" << type2 << ")";
+
+  return success();
+}
+
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+  // Locate the tosa.variable that declares this symbol by walking the ops of
+  // the enclosing ModuleOp. Once the TOSA control flow and variable
+  // extensions are complete, this may instead use MLIR's symbol table for the
+  // lookup and a TOSA-specific graph traversal over the IR structure.
+  // NOTE(review): duplicates emit an error, yet success() may still follow.
+  StringRef symName = op.getName();
+  tosa::VariableOp varOp = nullptr;
+  auto thisOp = op.getOperation();
+  ModuleOp module = thisOp->template getParentOfType<ModuleOp>();
+  bool found = false;
+
+  module.walk([&](Operation *tempOp) {
+    // Skip the op being verified.
+    if (tempOp == thisOp)
+      return;
+
+    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
+      if (symName == tosaOp.getName()) {
+        if (found == true) {
+          op->emitOpError("illegal to have multiple declaration of '")
+              << symName << "'";
+          return;
+        }
+        found = true;
+        varOp = tosaOp;
+      }
+    }
+  });
+
+  if (found == false)
+    return op->emitOpError("'")
+           << symName << "' has not been declared by 'tosa.variable'";
+
+  // Check element type and shape against the declared variable type.
+  Type varType = cast<tosa::VariableOp>(varOp).getType();
+  if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+          .failed())
+    return failure();
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3455,6 +3523,22 @@ LogicalResult tosa::SelectOp::verify() {
   return success();
 }
 
+LogicalResult tosa::VariableReadOp::verify() {
+  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
+          .failed())
+    return failure();
+
+  return success();
+}
+
+LogicalResult tosa::VariableWriteOp::verify() {
+  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
+          .failed())
+    return failure();
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
 // RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
 
 
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
-  tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
-  return
-}
-
 // -----
 
 // check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b147c94fde9b0..eba65eabe97fb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -595,7 +595,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+  // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
   %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
   return
 }
@@ -604,7 +604,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+  // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8')}}
   %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
   return
 }
@@ -613,7 +613,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+  // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
   tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
   return
 }
@@ -622,7 +622,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
 
 func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+  // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
   tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..3d2505c27ee58 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,93 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+
+func.func @test_variable_shape_mismatch() -> () {
+  // expected-error at +1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+  tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+  // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+  return
+}
+
+// -----
+
+func.func @test_variable_type_mismatch() -> () {
+  // expected-error at +1 {{expected integer elements, but parsed floating-point}}
+  tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+  // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+  return
+}
+
+// -----
+
+func.func @test_variable_read_no_declaration() -> () {
+  // expected-error at +1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+  %0 = tosa.variable_read @stored_var : tensor<f32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_multiple_declaration() -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_type_mismatch() -> () {
+  tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+  // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_shape_mismatch() -> () {
+  tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+  // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+  // expected-error at +1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+  tosa.variable_write @stored_var, %arg0 : tensor<f32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
+  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+  // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+  tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+  // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+  return
+}



More information about the Mlir-commits mailing list