[Mlir-commits] [mlir] [tosa] Change VariableOp to align with spec (PR #142240)
Tai Ly
llvmlistbot at llvm.org
Fri May 30 16:41:36 PDT 2025
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/142240
This fixes the TOSA VariableOp to align with the 1.0 specification (an example of the updated assembly format follows the list below):
- add a var_shape attribute to store the shape of the variable type
- change the type attribute to store the element type of the variable type
- add a builder so that previous construction calls still work
- fix the rank level check to run on the variable type instead of the initial value, which is optional
- add a size level check for the variable type
- add lit tests for variable ops without initial values
- add a lit test for a variable op with fixed rank but an unknown dimension
- add an invalid lit test for a variable op with an unranked type
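For reference, here is a minimal sketch of the assembly format after this change, taken from the lit tests in this patch (in both forms the shape is stored in var_shape and the element type in type):

  // With an initial value: var_shape and the element type are derived from the
  // typed initial-value attribute.
  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>

  // Without an initial value: the tensor type after the colon supplies the
  // shape and element type; unranked types are rejected by the parser.
  tosa.variable @var_x : tensor<f32>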
From 786379b11b505000b2fad6973ef9e91f214855a9 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 29 May 2025 00:37:00 +0000
Subject: [PATCH] [tosa] Change VariableOp to align with spec
This fixes the TOSA VariableOp to align with the 1.0 specification:
- add a var_shape attribute to store the shape of the variable type
- change the type attribute to store the element type of the variable type
- add a builder so that previous construction calls still work
- fix the rank level check to run on the variable type instead of the
initial value, which is optional
- add a size level check for the variable type
- add lit tests for variable ops without initial values
- add a lit test for a variable op with fixed rank but an unknown dimension
- add an invalid lit test for a variable op with an unranked type
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Icbbd751666870a94d4902163f7e840395e2aea52
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 10 ++
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 15 +-
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 7 +-
.../TosaToMLProgram/TosaToMLProgram.cpp | 3 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 133 ++++++++++++++----
.../Tosa/Transforms/TosaProfileCompliance.cpp | 11 +-
.../Tosa/Transforms/TosaValidation.cpp | 70 ++++++---
.../TosaToMLProgram/tosa-to-mlprogram.mlir | 16 ++-
mlir/test/Dialect/Tosa/invalid.mlir | 17 +++
mlir/test/Dialect/Tosa/level_check.mlir | 9 +-
mlir/test/Dialect/Tosa/variables.mlir | 45 ++++++
11 files changed, 267 insertions(+), 69 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 0aef4653b74ff..e048f8af7cc33 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
input, paddings);
}]>;
+// This builder is called on the TOSA variable operator with a variable type
+// and optional initial value. The builder will extract var_shape and element type
+// attributes from variable type.
+def Tosa_VariableOpBuilder : OpBuilder<
+ (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+ [{
+ buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+ }]>;
+
+
// Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 6fa4aedc1f0b0..a15f073bc5fcb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -44,10 +44,14 @@ class PatternRewriter;
namespace tosa {
-ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
- Attribute &attr);
-void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
- Attribute attr);
+ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
+ DenseElementsAttr &varShapeAttr,
+ TypeAttr &typeAttr,
+ Attribute &initialValueAttr);
+void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
+ DenseElementsAttr varShapeAttr,
+ TypeAttr typeAttr,
+ Attribute initialValueAttr);
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
int32_t val = 0);
+// returns type of variable op
+RankedTensorType getVariableType(VariableOp variableOp);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 5f99162907949..c8f2907f8dd1b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
let arguments = (ins
SymbolNameAttr:$name,
+ IndexElementsAttr:$var_shape,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
);
@@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
Extension<[Tosa_EXT_VARIABLE]>,
];
+ let hasCustomAssemblyFormat = 1;
+
let assemblyFormat = [{
$name
attr-dict
- custom<TypeOrAttr>($type, $initial_value)
+ custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
}];
+ let builders = [Tosa_VariableOpBuilder];
+
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index 310566e692202..7dbccd19a0518 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
+ auto variableType = tosa::getVariableType(op);
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
- op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+ op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 93a6a8be48df7..6a1639104846e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
return {&getBodyGraph()};
}
+//===----------------------------------------------------------------------===//
+// TOSA variable operator support.
+//===----------------------------------------------------------------------===//
+
+static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
+ return to_vector(llvm::map_range(shape, [](int64_t dim) {
+ return dim == -1 ? ShapedType::kDynamic : dim;
+ }));
+}
+
+// returns type of variable op
+RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
+ Type elementType = variableOp.getType();
+ DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+ auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+ return RankedTensorType::get(shape, elementType);
+}
+
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//
@@ -177,42 +195,81 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Parsers and printers
//===----------------------------------------------------------------------===//
-ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
- Attribute &attr) {
+namespace {
+
+ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
+ DenseElementsAttr &varShapeAttr,
+ TypeAttr &typeAttr) {
+ if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
+ if (!shapedType.hasRank())
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected ranked type";
+
+ auto elementType = shapedType.getElementType();
+ typeAttr = TypeAttr::get(elementType);
+ ArrayRef<int64_t> shape = shapedType.getShape();
+ Builder builder(parser.getContext());
+ varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+ return success();
+ }
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected shaped type";
+}
+
+} // namespace
+
+// parses the optional initial value or type for a tosa variable
+// with initial value:
+// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
+//
+// without initial value:
+// tosa.variable @name : tensor<1x8xf32>
+ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
+ OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
+ Attribute &initialValueAttr) {
if (succeeded(parser.parseOptionalEqual())) {
- if (failed(parser.parseAttribute(attr))) {
+ if (failed(parser.parseAttribute(initialValueAttr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
- if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
- typeAttr = TypeAttr::get(typedAttr.getType());
+ if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
+ return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
+ typeAttr);
}
- return success();
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected Typed attr";
}
- Type type;
- if (failed(parser.parseColonType(type))) {
- return parser.emitError(parser.getCurrentLocation()) << "expected type";
+ initialValueAttr = nullptr;
+ Type parsedType;
+ if (failed(parser.parseColonType(parsedType))) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected type after colon";
}
- typeAttr = TypeAttr::get(type);
-
- return success();
+ return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}
-void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
- Attribute attr) {
+void mlir::tosa::printVariableOpTypeOrInitialValue(
+ OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
+ TypeAttr typeAttr, Attribute initialValueAttr) {
bool needsSpace = false;
- auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
- if (!typedAttr || typedAttr.getType() != type.getValue()) {
+ auto typedAttr = dyn_cast_or_null<TypedAttr>(initialValueAttr);
+ if (!typedAttr) {
+ auto shape =
+ convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+ Type elementType = typeAttr.getValue();
+ RankedTensorType tensorType =
+ RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
+ auto tensorTypeAttr = TypeAttr::get(tensorType);
p << ": ";
- p.printAttribute(type);
+ p.printAttribute(tensorTypeAttr);
needsSpace = true; // subsequent attr value needs a space separator
}
- if (attr) {
+ if (initialValueAttr) {
if (needsSpace)
p << ' ';
p << "= ";
- p.printAttribute(attr);
+ p.printAttribute(initialValueAttr);
}
}
@@ -657,8 +714,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
<< symName << "' has not been declared by 'tosa.variable'";
// Verify type and shape
- Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
- if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+ auto variableType = getVariableType(varOp.value());
+ if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
+ "the input tensor")
.failed())
return failure();
@@ -1103,6 +1161,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.types.push_back(outputType);
}
+static void buildVariableOp(OpBuilder &builder, OperationState &result,
+ StringRef name, Type variableType,
+ Attribute initialValue) {
+ const Location loc{result.location};
+ auto nameAttr = builder.getStringAttr(name);
+
+ auto shapedType = dyn_cast<ShapedType>(variableType);
+ if (!shapedType) {
+ (void)emitError(loc, "variable type must be a shaped type");
+ return;
+ }
+ if (!shapedType.hasRank()) {
+ (void)emitError(loc, "variable type must be a ranked type");
+ return;
+ }
+
+ auto elementType = shapedType.getElementType();
+ auto elementTypeAttr = TypeAttr::get(elementType);
+ ArrayRef<int64_t> shape = shapedType.getShape();
+ auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+
+ result.addAttribute("name", nameAttr);
+ result.addAttribute("var_shape", varShapeAttr);
+ result.addAttribute("type", elementTypeAttr);
+ result.addAttribute("initial_value", initialValue);
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//
@@ -1676,12 +1761,6 @@ LogicalResult tosa::PadOp::verify() {
return success();
}
-static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
- return dim == -1 ? ShapedType::kDynamic : dim;
- }));
-}
-
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 1a896c1464e1c..de08e7e9a4394 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
- ::mlir::Attribute attr = op.getInitialValueAttr();
- if (attr == nullptr)
- return failure();
-
- if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
- addType(getElementTypeOrSelf(typedAttr));
- return success();
- }
- return failure();
+ addType(op.getType());
+ return success();
}
template <>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f9db5dcb88b4c..ea862ecb49e4e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -238,10 +238,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
- template <typename T>
- bool levelCheckRank(Operation *op, const T &v,
+ // Perform the Level Rank check on the tensor type.
+ bool levelCheckRank(Operation *op, const Type typeToCheck,
const StringRef operandOrResult, int32_t highest_rank) {
- if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+ if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
if (!type.hasRank()) {
op->emitOpError() << "failed level check: unranked tensor";
return false;
@@ -255,10 +255,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
- // Perform the Level tensor size check on the input tensor.
- bool levelCheckSize(Operation *op, const Value &v,
+ // Perform the Level Rank check on the tensor value.
+ bool levelCheckRank(Operation *op, const Value &v,
+ const StringRef operandOrResult, int32_t highest_rank) {
+ return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
+ }
+
+ // Perform the Level tensor size check on the tensor type.
+ bool levelCheckSize(Operation *op, const Type &typeToCheck,
const StringRef operandOrResult);
+ // Perform the Level tensor size check on the tensor value.
+ bool levelCheckSize(Operation *op, const Value &v,
+ const StringRef operandOrResult) {
+ return levelCheckSize(op, v.getType(), operandOrResult);
+ }
+
// Level check sizes of all operands and results of the operation.
template <typename T>
bool levelCheckSizes(T tosaOp) {
@@ -284,15 +296,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return false;
}
- if (!op->getAttrs().empty()) {
- for (NamedAttribute attr : op->getAttrs()) {
- if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
- if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
- return false;
- }
- }
- }
-
for (auto v : op->getResults()) {
if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
return false;
@@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
return true;
}
+template <>
+bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+ auto op = tosaOp.getOperation();
+ auto variableType = getVariableType(tosaOp);
+ if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
+ return false;
+
+ return true;
+}
+
+template <>
+bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+ auto op = tosaOp.getOperation();
+ auto variableType = getVariableType(tosaOp);
+ if (!levelCheckSize(op, variableType, "variable type"))
+ return false;
+
+ return true;
+}
+
bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_RANKS_AND_SIZES(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
@@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
return true;
}
-// Perform the Level tensor size check
-bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
+// Perform the Level tensor size check on the tensor type.
+bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
const StringRef operandOrResult) {
- if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+ if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
if (!type.hasRank()) {
op->emitOpError() << "failed level check: unranked tensor";
return false;
@@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type,
}
bool TosaValidation::CheckVariable(Operation *op) {
- if (isa<mlir::tosa::VariableOp>(op)) {
- mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
+ if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
+ mlir::StringAttr nameAttr = variableOp.getNameAttr();
if (variablesMap.count(nameAttr)) {
op->emitOpError() << "name has already been declared";
return false;
}
- auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
- mlir::Type type = typeAttr.getValue();
+ auto elementType = variableOp.getType();
+ DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+ SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
+ RankedTensorType variableType =
+ RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
- variablesMap[nameAttr] = type;
+ variablesMap[nameAttr] = variableType;
}
return true;
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
index 365b05ff084da..d2092753f1f58 100644
--- a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
+// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s
module {
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
@@ -10,4 +10,18 @@ module {
%0 = tosa.variable_read @var_x : tensor<1xf32>
return %0 : tensor<1xf32>
}
+}
+
+// -----
+
+module {
+ // CHECK: ml_program.global private mutable @var_x : tensor<f32>
+ tosa.variable @var_x : tensor<f32>
+ func.func @test_stateful_ops(%arg0: tensor<f32>) -> (tensor<f32>) {
+ // CHECK: ml_program.global_store @var_x = %arg0 : tensor<f32>
+ tosa.variable_write @var_x, %arg0 : tensor<f32>
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<f32>
+ %0 = tosa.variable_read @var_x : tensor<f32>
+ return %0 : tensor<f32>
+ }
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c41f079ec526c..05505c3671674 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
// -----
+func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+ tosa.variable @stored_var : tensor<*xi8>
+ // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
+ return
+}
+
+// -----
+
+func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+ // expected-error@+1 {{elements literal type must have static shape}}
+ tosa.variable @stored_var = dense<0> : tensor<*xi8>
+ // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7d0a0e1fa4ea..223bf3b635e18 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
// -----
func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
- // expected-error@+1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}}
+ // expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = "tosa.const"() {values = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
return %0: tensor<1x1x1x1x1x1x1xi32>
}
@@ -1089,7 +1089,8 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
// -----
func.func @test_variable_read_write_tensor_size_invalid() -> () {
- tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
+ // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ tosa.variable @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
@@ -1156,8 +1157,8 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
// -----
func.func @test_variable_read_write_rank_invalid() -> () {
- // expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
- tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
+ // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
+ tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 25f63331f39df..9953eb375d3ac 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -31,3 +31,48 @@ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
return
}
+
+// -----
+// CHECK-LABEL: @test_variable_scalar_no_initial_value(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+ // CHECK: tosa.variable @stored_var : tensor<f32>
+ tosa.variable @stored_var : tensor<f32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable_write @stored_var, %1 : tensor<f32>
+ return
+}
+
+// -----
+// CHECK-LABEL: @test_variable_tensor_no_initial_value(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
+ tosa.variable @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+// CHECK-LABEL: @test_variable_tensor_with_unknowns(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32>
+ tosa.variable @stored_var : tensor<2x?x8xi32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+}