[Mlir-commits] [mlir] 30bedb3 - [mlir][tosa] Add error if checks Variable Operators (#137291)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 29 08:50:22 PDT 2025
Author: TatWai Chong
Date: 2025-04-29T16:50:18+01:00
New Revision: 30bedb318611d84cf9b0672f6e0675d33f90d2c8
URL: https://github.com/llvm/llvm-project/commit/30bedb318611d84cf9b0672f6e0675d33f90d2c8
DIFF: https://github.com/llvm/llvm-project/commit/30bedb318611d84cf9b0672f6e0675d33f90d2c8.diff
LOG: [mlir][tosa] Add error if checks Variable Operators (#137291)
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/variables.mlir
mlir/test/Dialect/Tosa/verifier.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..5f99162907949 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
attr-dict
custom<TypeOrAttr>($type, $initial_value)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];
+
+ let hasVerifier = 1;
}
#endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b2e471f2bba93..c669bc4a31d43 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -613,6 +613,58 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
+// Returns the first tosa.variable named `symName` that the walk visits before
+// reaching `op` itself, or failure if there is none.
+static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
+                                                    StringRef symName) {
+  // Search the enclosing ModuleOp; if there is none (detached op, non-module
+  // top level), use the outermost ancestor instead of a null ModuleOp.
+  Operation *scope = op->getParentOfType<ModuleOp>();
+  if (!scope) {
+    scope = op;
+    while (Operation *parent = scope->getParentOp())
+      scope = parent;
+  }
+  tosa::VariableOp varOp = nullptr;
+
+  // TODO: Adopt the SymbolTable trait for Variable ops. Until the TOSA
+  // control flow and variable extensions are complete, the declaration is
+  // found by a walk() from `scope` that stops at the point of use; MLIR's
+  // symbol table (or a TOSA-specific graph traversal) could replace it.
+  scope->walk([&](Operation *tempOp) {
+    // Stop at `op` itself: only declarations visited before it count.
+    if (tempOp == op)
+      return WalkResult::interrupt();
+
+    auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp);
+    if (tosaOp && symName == tosaOp.getName()) {
+      varOp = tosaOp;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  if (varOp)
+    return varOp;
+  return failure();
+}
+
+// Verifies that the tosa.variable named by `op` is found before `op`, and that
+// `type` (reported as `name`) matches the variable's element type and shape.
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+  StringRef symName = op.getName();
+  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
+  if (failed(varOp))
+    return op->emitOpError("'")
+           << symName << "' has not been declared by 'tosa.variable'";
+
+  // Compare against the declared variable type. Diagnostics refer to it as
+  // "the input tensor"; existing tests match on that wording.
+  Type varType = varOp->getType();
+  return errorIfTypeOrShapeMismatch(op, type, name, varType,
+                                    "the input tensor");
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3660,6 +3712,32 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
+// A variable name may be declared only once: this tosa.variable is rejected
+// when another tosa.variable with the same name is found before it.
+LogicalResult tosa::VariableOp::verify() {
+  StringRef symName = getName();
+  if (failed(findVariableDecl(*this, symName)))
+    return success();
+  return emitOpError("illegal to have multiple declaration of '")
+         << symName << "'";
+}
+
+// Checks that the variable this op reads has been declared and that 'output1'
+// matches that variable's element type and shape.
+LogicalResult tosa::VariableReadOp::verify() {
+  Type outputType = getOutput1().getType();
+  return verifyVariableOpErrorIf(*this, outputType, "'output1'");
+}
+
+// Checks that the variable this op writes has been declared and that 'input1'
+// matches that variable's element type and shape.
+LogicalResult tosa::VariableWriteOp::verify() {
+  Type inputType = getInput1().getType();
+  return verifyVariableOpErrorIf(*this, inputType, "'input1'");
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
-}
-
// -----
// check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c4f95b47628d1..9ccb310c4491d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
+// A second tosa.variable reusing the name 'stored_var' is rejected.
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable' op name has already been declared}}
+ // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}
@@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
+// Reading an i8 variable as an i16 tensor is rejected (element type mismatch).
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
@@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+// Both the shape and the element type differ here; the element type mismatch
+// is the one reported.
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
@@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+// Writing an i16 operand to an i8 variable is rejected (element type mismatch).
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
@@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+// Writing a 1x4x8 operand to a 2x4x8 variable is rejected (shape mismatch).
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 7ae8ec470c3dd..990e0d954f54e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -785,3 +785,82 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te
}
return
}
+
+// -----
+
+// A second tosa.variable reusing the name 'stored_var' is rejected.
+func.func @test_variable_multiple_declaration() -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
+ tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// The initial value literal has 2 elements but the declared type has 3, so
+// parsing the tosa.variable fails.
+func.func @test_variable_shape_mismatch() -> () {
+ // expected-error at +1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+ tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+// A floating-point initial value for an i32 variable fails to parse.
+func.func @test_variable_type_mismatch() -> () {
+ // expected-error at +1 {{expected integer elements, but parsed floating-point}}
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+ // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+// Reading a variable with no preceding tosa.variable declaration is rejected.
+func.func @test_variable_read_no_declaration() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ return
+}
+
+// -----
+
+// The variable_read result element type (i32) must match the declared
+// variable element type (f32).
+func.func @test_variable_read_type_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// The variable_read result shape must match the declared variable shape.
+func.func @test_variable_read_shape_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+}
+
+// -----
+
+// Writing a variable with no preceding tosa.variable declaration is rejected.
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<f32>
+ return
+}
+
+// -----
+
+// The variable_write operand element type (i32) must match the declared
+// variable element type (f32).
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+// The variable_write operand shape must match the declared variable shape.
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+}
More information about the Mlir-commits
mailing list