[Mlir-commits] [mlir] [mlir][tosa] Add error if checks Variable Operators (PR #137291)
TatWai Chong
llvmlistbot at llvm.org
Mon Apr 28 19:58:18 PDT 2025
https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/137291
>From 1b6299c6e47f98fb0da00e549b4a83f0c9dcafdc Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Thu, 24 Apr 2025 23:38:38 -0700
Subject: [PATCH] [mlir][tosa] Add error if checks Variable Operators
Adds ERROR_IF checks for the VARIABLE, VARIABLE_WRITE and VARIABLE_READ operators.
---
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 6 ++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 102 ++++++++++++++++++
.../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 --
mlir/test/Dialect/Tosa/invalid.mlir | 10 +-
mlir/test/Dialect/Tosa/variables.mlir | 4 +-
mlir/test/Dialect/Tosa/verifier.mlir | 81 +++++++++++++-
6 files changed, 195 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..5f99162907949 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
attr-dict
custom<TypeOrAttr>($type, $initial_value)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];
+
+ let hasVerifier = 1;
}
#endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d535009f34533..49f90980311f6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -562,6 +562,82 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();
}
+// Verifies that `type1` and `type2` are shaped types with the same element
+// type and compatible shapes. `name1` and `name2` label the two types in the
+// emitted diagnostics. Emits an op error on `op` and returns failure on any
+// mismatch; returns success otherwise.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+                                                StringRef name1, Type type2,
+                                                StringRef name2) {
+  auto shapeType1 = dyn_cast<ShapedType>(type1);
+  auto shapeType2 = dyn_cast<ShapedType>(type2);
+  // A verifier must not fail silently: report the offending types instead of
+  // returning a bare failure with no diagnostic attached.
+  if (!shapeType1 || !shapeType2)
+    return op->emitOpError()
+           << "require shaped types for " << name1 << " (" << type1
+           << ") and " << name2 << " (" << type2 << ")";
+
+  Type elemType1 = shapeType1.getElementType();
+  Type elemType2 = shapeType2.getElementType();
+  if (elemType1 != elemType2)
+    return op->emitOpError()
+           << "require same element type for " << name1 << " (" << elemType1
+           << ") and " << name2 << " (" << elemType2 << ")";
+
+  if (failed(verifyCompatibleShape(type1, type2)))
+    return op->emitOpError()
+           << "require same shapes for " << name1 << " (" << type1 << ") and "
+           << name2 << " (" << type2 << ")";
+
+  return success();
+}
+
+// Searches the nearest enclosing ModuleOp for a tosa.variable named `symName`
+// that is visited before `op` itself in walk order. Returns the first such
+// declaration, or failure if none is found or if `op` is not nested inside a
+// ModuleOp.
+//
+// NOTE: every call walks the module from its beginning, so verifying many
+// variable ops this way can be quadratic in the number of ops in the module.
+static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
+                                                    StringRef symName) {
+  ModuleOp module = op->getParentOfType<ModuleOp>();
+  // An op that is not (yet) nested in a module has nothing to search; bail
+  // out instead of walking a null module.
+  if (!module)
+    return failure();
+
+  tosa::VariableOp varOp = nullptr;
+
+  // TODO: Adopt SymbolTable trait to Variable ops.
+  // Currently, the variable's definition point is searched via walk(),
+  // starting from the nearest enclosing ModuleOp and stopping at the point of
+  // use. Once TOSA control flow and variable extensions reach the complete
+  // state, may leverage MLIR's Symbol Table functionality to look up symbol
+  // and enhance the search to a TOSA specific graph traversal over the IR
+  // structure.
+  module.walk([&](Operation *tempOp) {
+    // Stop once this op itself is reached: only earlier declarations count.
+    if (tempOp == op)
+      return WalkResult::interrupt();
+
+    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
+      if (symName == tosaOp.getName()) {
+        varOp = tosaOp;
+        return WalkResult::interrupt();
+      }
+    }
+
+    return WalkResult::advance();
+  });
+
+  if (varOp)
+    return varOp;
+
+  return failure();
+}
+
+// Shared verifier for ops that reference a variable by name: checks that an
+// earlier tosa.variable declares `op.getName()`, and that `type` (labelled
+// `name` in diagnostics) matches the declared variable type in element type
+// and shape. Emits an op error and returns failure otherwise.
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+  StringRef symName = op.getName();
+  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
+  if (failed(varOp))
+    return op->emitOpError("'")
+           << symName << "' has not been declared by 'tosa.variable'";
+
+  // Verify type and shape against the declaration.
+  // NOTE(review): the declared variable type is labelled "the input tensor"
+  // in diagnostics, which is ambiguous for variable_write; the existing
+  // expected-error tests pin this wording, so change both together.
+  Type varType = varOp->getType();
+  return errorIfTypeOrShapeMismatch(op, type, name, varType,
+                                    "the input tensor");
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3531,6 +3607,32 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
+// A variable name may be declared only once: this op is rejected when an
+// earlier tosa.variable with the same name is found in the enclosing module.
+LogicalResult tosa::VariableOp::verify() {
+  const StringRef name = getName();
+  if (failed(findVariableDecl(*this, name)))
+    return success();
+
+  return emitOpError("illegal to have multiple declaration of '")
+         << name << "'";
+}
+
+// The result type must match the referenced variable's declared type; the
+// shared helper emits the diagnostic on mismatch.
+LogicalResult tosa::VariableReadOp::verify() {
+  return verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'");
+}
+
+// The written operand's type must match the referenced variable's declared
+// type; the shared helper emits the diagnostic on mismatch.
+LogicalResult tosa::VariableWriteOp::verify() {
+  return verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'");
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
-}
-
// -----
// check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c4f95b47628d1..9ccb310c4491d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
+// Checks that redeclaring @stored_var with a second tosa.variable is rejected.
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable' op name has already been declared}}
+ // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}
@@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
+// Checks that reading @stored_var as i16 when it was declared as i8 is rejected.
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
@@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+// Both the element type (i32 vs. i8) and the shape differ here; the element
+// type is checked first, so that is the mismatch reported.
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
@@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+// Checks that writing an i16 tensor to a variable declared as i8 is rejected.
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
@@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+// Checks that writing a 1x4x8 tensor to a variable declared as 2x4x8 is rejected.
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index c669c36e5452f..986f981456eef 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -436,4 +436,83 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
// expected-error at +1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}}
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8>
return %1 : tensor<10xi8>
-}
\ No newline at end of file
+}
+
+// -----
+
+// Checks that declaring @stored_var twice is rejected.
+func.func @test_variable_multiple_declaration() -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
+ tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// Checks that an initial value with 2 elements for a tensor<3xf32> variable is
+// rejected while parsing the tosa.variable.
+func.func @test_variable_shape_mismatch() -> () {
+ // expected-error at +1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+ tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+// Checks that a floating-point initial value for an i32 variable is rejected
+// while parsing the tosa.variable.
+func.func @test_variable_type_mismatch() -> () {
+ // expected-error at +1 {{expected integer elements, but parsed floating-point}}
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+// Checks that reading a variable with no preceding tosa.variable is rejected.
+func.func @test_variable_read_no_declaration() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ return
+}
+
+// -----
+
+// Checks that reading an f32 variable as i32 is rejected.
+func.func @test_variable_read_type_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// Checks that reading an 8x4x2 variable as a 2x4x8 tensor is rejected.
+func.func @test_variable_read_shape_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+}
+
+// -----
+
+// Checks that writing a variable with no preceding tosa.variable is rejected.
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<f32>
+ return
+}
+
+// -----
+
+// Checks that writing an i32 tensor to an f32 variable is rejected.
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// Checks that writing a 2x4x8 tensor to an 8x4x2 variable is rejected.
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+}
More information about the Mlir-commits
mailing list