[Mlir-commits] [mlir] [mlir][tosa] Apply 'Symbol' trait to `tosa.variable` (PR #153223)
Luke Hutton
llvmlistbot at llvm.org
Sat Oct 4 03:09:26 PDT 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/153223
>From da051fedef967f56605eaf37e6816073f507b646 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 6 Aug 2025 10:04:52 +0000
Subject: [PATCH] [mlir][tosa] Apply 'Symbol' trait to `tosa.variable`
Implement SymbolOpInterface on tosa.variable so that its declaration is
automatically inserted into its parent's SymbolTable.
Verifiers for tosa.variable_read/write can now look up the symbol and
guarantee it exists, and duplicate names are caught at creation time.
Previously this was completed by walking the graph which could be
inefficient.
Unfortunately, the Symbol trait expects to find a symbol name
via a hard-coded attribute name "sym_name". Therefore, "name" is renamed
to "sym_name" and a getName() wrapper is provided for backwards
compatibility.
This change also restricts tosa.variable declarations to ops that carry
a SymbolTable (e.g. modules), rather than allowing them to be placed
inside a func.func.
Note: EXT-VARIABLE is an experimental extension in the TOSA
specification, so is not subject to backwards compatibility
guarantees.
Change-Id: I00a3f8f3b3b4f68cb3c120fe2c928d7b74b214cb
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 4 +-
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 16 ++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 67 ++-------
mlir/test/Dialect/Tosa/invalid.mlir | 55 ++++----
mlir/test/Dialect/Tosa/invalid_extension.mlir | 22 +--
mlir/test/Dialect/Tosa/level_check.mlir | 30 ++--
mlir/test/Dialect/Tosa/variables.mlir | 132 ++++++++++--------
mlir/test/Dialect/Tosa/verifier.mlir | 56 +++++---
8 files changed, 197 insertions(+), 185 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 115a11b346780..80337fc30bc66 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
// and optional initial value. The builder will extract var_shape and element type
// attributes from variable type.
def Tosa_VariableOpBuilder : OpBuilder<
- (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+ (ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
[{
- buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+ buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
}]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d819cc198e3f2..f1a618e75061b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -18,6 +18,7 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
//===----------------------------------------------------------------------===//
// Operator: variable
//===----------------------------------------------------------------------===//
-def Tosa_VariableOp : Tosa_Op<"variable", []> {
+def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
let summary = "Defines a variable";
let description = [{
@@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
}];
let arguments = (ins
- SymbolNameAttr:$name,
+ // Note: "sym_name" is used as opposed to "name" in the specification,
+ // since a Symbol must be named "sym_name" for it to be recognised by
+ // the containing SymbolTable.
+ SymbolNameAttr:$sym_name,
IndexElementsAttr:$var_shape,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
@@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
- $name
+ $sym_name
attr-dict
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
}];
let builders = [Tosa_VariableOpBuilder];
- let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ ::llvm::StringRef getName() {
+ return getSymName();
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 332f1a0e5506f..c51b5e9cbfc78 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
- StringRef symName) {
- ModuleOp module = op->getParentOfType<ModuleOp>();
- tosa::VariableOp varOp = nullptr;
-
- // TODO: Adopt SymbolTable trait to Varible ops.
- // Currently, the variable's definition point is searched via walk(),
- // starting from the top-level ModuleOp and stopping at the point of use. Once
- // TOSA control flow and variable extensions reach the complete state, may
- // leverage MLIR's Symbol Table functionality to look up symbol and enhance
- // the search to a TOSA specific graph traversal over the IR structure.
- module.walk([&](Operation *tempOp) {
- // Reach this op itself.
- if (tempOp == op) {
- return WalkResult::interrupt();
- }
-
- if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
- if (symName == tosaOp.getName()) {
- varOp = tosaOp;
- return WalkResult::interrupt();
- }
- }
-
- return WalkResult::advance();
- });
-
- if (varOp)
- return varOp;
-
- return failure();
-}
-
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
- StringRef symName = op.getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
- if (failed(varOp))
+ Operation *symTableOp =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symTableOp)
+ // If the operation is not in the scope of a symbol table, we cannot
+ // verify it against its declaration.
+ return success();
+
+ SymbolTable symTable(symTableOp);
+ const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+ // Verify prior declaration
+ if (!varOp)
return op->emitOpError("'")
- << symName << "' has not been declared by 'tosa.variable'";
+ << op.getName() << "' has not been declared by 'tosa.variable'";
// Verify type and shape
- auto variableType = getVariableType(varOp.value());
+ auto variableType = getVariableType(varOp);
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();
-
return success();
}
@@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
- result.addAttribute("name", nameAttr);
+ result.addAttribute("sym_name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
@@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
-LogicalResult tosa::VariableOp::verify() {
- StringRef symName = getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
- if (succeeded(varOp))
- return emitOpError("illegal to have multiple declaration of '")
- << symName << "'";
-
- return success();
-}
-
LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 41c3243792259..e60f1c9b4a01a 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -573,64 +573,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
// -----
-func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+module {
tosa.variable @stored_var : tensor<*xi8>
// expected-error at +1 {{custom op 'tosa.variable' expected ranked type}}
- return
}
// -----
-func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+module {
// expected-error at +1 {{elements literal type must have static shape}}
tosa.variable @stored_var = dense<0> : tensor<*xi8>
// expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
- return
-}
-
-// -----
-
-func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
- tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
- return
}
// -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
- return
+ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
+ return
+ }
}
// -----
-func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
- %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
- return
+ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
+ %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
+ return
+ }
}
// -----
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
- tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
- return
+ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
+ return
+ }
}
// -----
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
+ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
+ return
+ }
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 3138ce2621a3a..1daabe9222a9b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
}
// -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
// expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_read' op illegal: requires [variable]}}
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
- return
+
+ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+ // expected-error at +1 {{'tosa.variable_read' op illegal: requires [variable]}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
+ return
+ }
}
// -----
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
// expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error at +1 {{'tosa.variable_write' op illegal: requires [variable]}}
- tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
- return
+
+ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op illegal: requires [variable]}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
+ return
+ }
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 3742adf650408..5bf2dbb8d02b1 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1097,14 +1097,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
// -----
-func.func @test_variable_read_write_tensor_size_invalid() -> () {
+module {
// expected-error at +1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable @stored_var : tensor<536870912xf32>
- // expected-error at +1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
- // expected-error at +1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
- return
+
+ func.func @test_variable_read_write_tensor_size_invalid() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
+ // expected-error at +1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
+ return
+ }
}
// -----
@@ -1165,14 +1168,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
// -----
-func.func @test_variable_read_write_rank_invalid() -> () {
+module {
// expected-error at +1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
- // expected-error at +1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
- %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
- // expected-error at +1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
- tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
- return
+
+ func.func @test_variable_read_write_rank_invalid() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
+ %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
+ // expected-error at +1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
+ tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
+ return
+ }
}
// -----
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 9953eb375d3ac..0c104e8e8d7ea 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -3,76 +3,98 @@
// -----
-// CHECK-LABEL: @test_variable_scalar(
-// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
- // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+
+module {
+ // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
tosa.variable @stored_var = dense<3.14> : tensor<f32>
- // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
- %0 = tosa.variable_read @stored_var : tensor<f32>
- // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
- %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
- // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
- tosa.variable_write @stored_var, %1 : tensor<f32>
- return
+
+ // CHECK-LABEL: @test_variable_scalar(
+ // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+ func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable_write @stored_var, %1 : tensor<f32>
+ return
+ }
}
+
// -----
-// CHECK-LABEL: @test_variable_tensor(
-// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
- // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+
+module {
+ // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
- // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
- tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
- return
+
+ // CHECK-LABEL: @test_variable_tensor(
+ // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+ }
}
// -----
-// CHECK-LABEL: @test_variable_scalar_no_initial_value(
-// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
- // CHECK: tosa.variable @stored_var : tensor<f32>
+
+module {
+ // CHECK: tosa.variable @stored_var : tensor<f32>
tosa.variable @stored_var : tensor<f32>
- // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
- %0 = tosa.variable_read @stored_var : tensor<f32>
- // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
- %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
- // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
- tosa.variable_write @stored_var, %1 : tensor<f32>
- return
+
+ // CHECK-LABEL: @test_variable_scalar_no_initial_value(
+ // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+ func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable_write @stored_var, %1 : tensor<f32>
+ return
+ }
}
// -----
-// CHECK-LABEL: @test_variable_tensor_no_initial_value(
-// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
- // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
+
+module {
+ // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
tosa.variable @stored_var : tensor<2x4x8xi32>
- // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
- tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
- return
+
+ // CHECK-LABEL: @test_variable_tensor_no_initial_value(
+ // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+ func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+ }
}
+
// -----
-// CHECK-LABEL: @test_variable_tensor_with_unknowns(
-// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
- // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32>
+
+module {
+ // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32>
tosa.variable @stored_var : tensor<2x?x8xi32>
- // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
- // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
- tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
- return
+
+ // CHECK-LABEL: @test_variable_tensor_with_unknowns(
+ // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+ func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+ }
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 0128da729136e..430b06ad16c39 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -944,29 +944,27 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te
// -----
-func.func @test_variable_multiple_declaration() -> () {
+module {
+ // expected-note at below {{see existing symbol definition here}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
- // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
+ // expected-error at +1 {{redefinition of symbol named 'stored_var'}}
tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
- return
}
// -----
-func.func @test_variable_shape_mismatch() -> () {
+module {
// expected-error at +1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
// expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
- return
}
// -----
-func.func @test_variable_type_mismatch() -> () {
+module {
// expected-error at +1 {{expected integer elements, but parsed floating-point}}
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
// expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
- return
}
// -----
@@ -979,20 +977,26 @@ func.func @test_variable_read_no_declaration() -> () {
// -----
-func.func @test_variable_read_type_mismatch() -> () {
+module {
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
- // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
- return
+
+ func.func @test_variable_read_type_mismatch() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+ }
}
// -----
-func.func @test_variable_read_shape_mismatch() -> () {
+module {
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
- // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
- %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
- return
+
+ func.func @test_variable_read_shape_mismatch() -> () {
+ // expected-error at +1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+ }
}
// -----
@@ -1005,20 +1009,26 @@ func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
// -----
-func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+module {
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
- // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
- tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
- return
+
+ func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+ }
}
// -----
-func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+module {
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
- // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
- tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
- return
+
+ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ // expected-error at +1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+ }
}
// -----
More information about the Mlir-commits
mailing list