[Mlir-commits] [mlir] [mlir][tosa] Add error if checks Variable Operators (PR #137291)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 25 00:25:56 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: TatWai Chong (tatwaichong)
<details>
<summary>Changes</summary>
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
---
Full diff: https://github.com/llvm/llvm-project/pull/137291.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+4)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+84)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (-10)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/variables.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+90)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..6e5f6317816f2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];
+
+ let hasVerifier = 1;
}
#endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..b1312afbbf6d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -572,6 +572,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();
}
+// Verify that `type1` and `type2` are shaped types with the same element type
+// and compatible shapes. `name1` / `name2` are the labels used for the two
+// types in diagnostics. On any mismatch an op error is emitted on `op` and
+// failure is returned; every failure path carries a diagnostic.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+                                                StringRef name1, Type type2,
+                                                StringRef name2) {
+  auto shapeType1 = dyn_cast<ShapedType>(type1);
+  auto shapeType2 = dyn_cast<ShapedType>(type2);
+  // A verifier must not return failure() without telling the user why.
+  if (!shapeType1 || !shapeType2)
+    return op->emitOpError()
+           << "require shaped types for " << name1 << " (" << type1
+           << ") and " << name2 << " (" << type2 << ")";
+
+  Type elemType1 = shapeType1.getElementType();
+  Type elemType2 = shapeType2.getElementType();
+  if (elemType1 != elemType2)
+    return op->emitOpError()
+           << "require same element type for " << name1 << " (" << elemType1
+           << ") and " << name2 << " (" << elemType2 << ")";
+
+  // Shapes are checked for compatibility (verifyCompatibleShape), not for
+  // strict equality.
+  if (failed(verifyCompatibleShape(type1, type2)))
+    return op->emitOpError()
+           << "require same shapes for " << name1 << " (" << type1 << ") and "
+           << name2 << " (" << type2 << ")";
+
+  return success();
+}
+
+// Verify the ERROR_IF conditions shared by tosa.variable_read and
+// tosa.variable_write:
+//  - the symbol `op.getName()` must be declared by exactly one tosa.variable
+//    that precedes `op` in the walk order;
+//  - `type` (labelled `name` in diagnostics) must have the same element type
+//    as, and a shape compatible with, the declared variable's type.
+// Returns failure (with a diagnostic already emitted) on any violation.
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+  // Currently, the variable's declaration is searched via walk(), starting
+  // from the enclosing ModuleOp and stopping at the point of use, so only
+  // declarations reached before `op` are considered. Once TOSA control flow
+  // and variable extensions reach the complete state, may leverage MLIR's
+  // Symbol Table functionality to look up the symbol and enhance the search to
+  // a TOSA specific graph traversal over the IR structure.
+  StringRef symName = op.getName();
+  Operation *thisOp = op.getOperation();
+
+  // Search from the nearest enclosing module. If the op is not nested in a
+  // ModuleOp, fall back to its top-most ancestor rather than dereferencing a
+  // null module.
+  Operation *searchRoot = thisOp;
+  if (ModuleOp module = thisOp->getParentOfType<ModuleOp>()) {
+    searchRoot = module.getOperation();
+  } else {
+    while (Operation *parent = searchRoot->getParentOp())
+      searchRoot = parent;
+  }
+
+  tosa::VariableOp varOp = nullptr;
+  bool duplicated = false;
+  searchRoot->walk([&](Operation *tempOp) -> WalkResult {
+    // Reached this op itself: stop searching (point of use).
+    if (tempOp == thisOp)
+      return WalkResult::interrupt();
+
+    auto declOp = dyn_cast<tosa::VariableOp>(tempOp);
+    if (!declOp || declOp.getName() != symName)
+      return WalkResult::advance();
+
+    // A second matching declaration is an error; no need to look further.
+    if (varOp) {
+      duplicated = true;
+      return WalkResult::interrupt();
+    }
+    varOp = declOp;
+    return WalkResult::advance();
+  });
+
+  // Propagate the failure: emitting the diagnostic alone would let the
+  // verifier report success.
+  if (duplicated)
+    return op->emitOpError("illegal to have multiple declaration of '")
+           << symName << "'";
+
+  if (!varOp)
+    return op->emitOpError("'")
+           << symName << "' has not been declared by 'tosa.variable'";
+
+  // Verify type and shape against the declaration.
+  return errorIfTypeOrShapeMismatch(op, type, name, varOp.getType(),
+                                    "the input tensor");
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3455,6 +3523,22 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
+// The type produced by a variable read must match the declared variable.
+LogicalResult tosa::VariableReadOp::verify() {
+  const Type resultType = getOutput1().getType();
+  return verifyVariableOpErrorIf(*this, resultType, "'output1'");
+}
+
+// The value written to a variable must match the declared variable.
+LogicalResult tosa::VariableWriteOp::verify() {
+  const Type writtenType = getInput1().getType();
+  return verifyVariableOpErrorIf(*this, writtenType, "'input1'");
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
-}
-
// -----
// check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b147c94fde9b0..eba65eabe97fb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -595,7 +595,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
@@ -604,7 +604,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
@@ -613,7 +613,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
@@ -622,7 +622,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..3d2505c27ee58 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,93 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
+
+// -----
+
+func.func @test_variable_shape_mismatch() -> () {
+ // expected-error at +1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+ tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_type_mismatch() -> () {
+ // expected-error at +1 {{expected integer elements, but parsed floating-point}}
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_read_no_declaration() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_multiple_declaration() -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error at +1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_type_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_shape_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error at +1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/137291
More information about the Mlir-commits
mailing list