[Mlir-commits] [mlir] [tosa] Change VariableOp to align with spec (PR #142240)

Tai Ly llvmlistbot at llvm.org
Fri May 30 16:41:36 PDT 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/142240

This fixes the TOSA VariableOp to align with the TOSA 1.0 specification:
  - add var_shape attribute to store shape of variable type
  - change type attribute to store element type of variable type
  - add a builder so previous construction calls still work
  - perform the rank level check on the variable type instead of the initial value, which is optional
  - add level check of size for variable type
  - add lit tests for variable ops without initial values
  - add lit test for variable op with fixed rank but unknown dimension
  - add invalid lit test for variable op with unranked type


>From 786379b11b505000b2fad6973ef9e91f214855a9 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 29 May 2025 00:37:00 +0000
Subject: [PATCH] [tosa] Change VariableOp to align with spec

This fixes the TOSA VariableOp to align with the TOSA 1.0 specification:
  - add var_shape attribute to store shape of variable type
  - change type attribute to store element type of variable type
  - add a builder so previous construction calls still work
  - perform the rank level check on the variable type instead of the
    initial value, which is optional
  - add level check of size for variable type
  - add lit tests for variable ops without initial values
  - add lit test for variable op with fixed rank but unknown dimension
  - add invalid lit test for variable op with unranked type

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Icbbd751666870a94d4902163f7e840395e2aea52
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |  10 ++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  15 +-
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       |   7 +-
 .../TosaToMLProgram/TosaToMLProgram.cpp       |   3 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 133 ++++++++++++++----
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  11 +-
 .../Tosa/Transforms/TosaValidation.cpp        |  70 ++++++---
 .../TosaToMLProgram/tosa-to-mlprogram.mlir    |  16 ++-
 mlir/test/Dialect/Tosa/invalid.mlir           |  17 +++
 mlir/test/Dialect/Tosa/level_check.mlir       |   9 +-
 mlir/test/Dialect/Tosa/variables.mlir         |  45 ++++++
 11 files changed, 267 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 0aef4653b74ff..e048f8af7cc33 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
                             input, paddings);
   }]>;
 
+// Builder for tosa.variable taking the full variable type plus an optional
+// initial value. It splits the type into the var_shape and element `type`
+// attributes, so callers that passed a whole tensor type keep working.
+def Tosa_VariableOpBuilder : OpBuilder<
+  (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+  [{
+    buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+  }]>;
+
+
 // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
      : I32EnumAttr<name, description, cases> {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 6fa4aedc1f0b0..a15f073bc5fcb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -44,10 +44,14 @@ class PatternRewriter;
 
 namespace tosa {
 
-ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                            Attribute &attr);
-void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                     Attribute attr);
+ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
+                                              DenseElementsAttr &varShapeAttr,
+                                              TypeAttr &typeAttr,
+                                              Attribute &initialValueAttr);
+void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
+                                       DenseElementsAttr varShapeAttr,
+                                       TypeAttr typeAttr,
+                                       Attribute initialValueAttr);
 
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
 Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
                            int32_t val = 0);
 
+// Returns the variable's tensor type, rebuilt from var_shape and type.
+RankedTensorType getVariableType(VariableOp variableOp);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 5f99162907949..c8f2907f8dd1b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
 
   let arguments = (ins
     SymbolNameAttr:$name,
+    IndexElementsAttr:$var_shape,
     TypeAttr:$type,
     OptionalAttr<AnyAttr>:$initial_value
   );
@@ -101,12 +102,14 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
     Extension<[Tosa_EXT_VARIABLE]>,
   ];
 
   let assemblyFormat = [{
     $name
     attr-dict
-    custom<TypeOrAttr>($type, $initial_value)
+    custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
   }];
 
+  let builders = [Tosa_VariableOpBuilder];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index 310566e692202..7dbccd19a0518 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
 
   LogicalResult matchAndRewrite(tosa::VariableOp op,
                                 PatternRewriter &rewriter) const final {
+    auto variableType = tosa::getVariableType(op);
     auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
-        op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+        op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
         op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
     newVariable.setPrivate();
     rewriter.replaceOp(op, newVariable);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 93a6a8be48df7..6a1639104846e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
   return {&getBodyGraph()};
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA variable operator support.
+//===----------------------------------------------------------------------===//
+
+static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return dim == -1 ? ShapedType::kDynamic : dim;
+  }));
+}
+
+// Rebuilds the variable's tensor type from var_shape (-1 => dynamic) and type.
+RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
+  Type elementType = variableOp.getType();
+  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+  return RankedTensorType::get(shape, elementType);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
 //===----------------------------------------------------------------------===//
@@ -177,42 +195,81 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Parsers and printers
 //===----------------------------------------------------------------------===//
 
-ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                                        Attribute &attr) {
+// Splits a parsed variable type into the attributes stored on tosa.variable:
+// `typeAttr` gets the element type and `varShapeAttr` gets the shape passed
+// through convertFromMlirShape (presumably the inverse of convertToMlirShape,
+// i.e. dynamic dims stored as -1). Non-shaped and unranked types are rejected.
+// `static` rather than an anonymous namespace, per LLVM coding standards.
+static ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
+                                          DenseElementsAttr &varShapeAttr,
+                                          TypeAttr &typeAttr) {
+  auto shapedType = dyn_cast<ShapedType>(parsedType);
+  if (!shapedType)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected shaped type";
+  if (!shapedType.hasRank())
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected ranked type";
+
+  typeAttr = TypeAttr::get(shapedType.getElementType());
+  Builder builder(parser.getContext());
+  varShapeAttr =
+      builder.getIndexTensorAttr(convertFromMlirShape(shapedType.getShape()));
+  return success();
+}
+
+// parses the optional initial value or type for a tosa variable
+//  with initial value:
+//    tosa.variable @name = dense<0.0> : tensor<1x8xf32>
+//
+//  without initial value:
+//    tosa.variable @name : tensor<1x8xf32>
+ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
+    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
+    Attribute &initialValueAttr) {
   if (succeeded(parser.parseOptionalEqual())) {
-    if (failed(parser.parseAttribute(attr))) {
+    if (failed(parser.parseAttribute(initialValueAttr))) {
       return parser.emitError(parser.getCurrentLocation())
              << "expected attribute";
     }
-    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-      typeAttr = TypeAttr::get(typedAttr.getType());
+    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
+      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
+                                    typeAttr);
     }
-    return success();
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected Typed attr";
   }
 
-  Type type;
-  if (failed(parser.parseColonType(type))) {
-    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  initialValueAttr = nullptr;
+  Type parsedType;
+  if (failed(parser.parseColonType(parsedType))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected type after colon";
   }
-  typeAttr = TypeAttr::get(type);
-
-  return success();
+  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
 }
 
-void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                                 Attribute attr) {
+void mlir::tosa::printVariableOpTypeOrInitialValue(
+    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
+    TypeAttr typeAttr, Attribute initialValueAttr) {
   bool needsSpace = false;
-  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
-  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+  auto typedAttr = dyn_cast_or_null<TypedAttr>(initialValueAttr);
+  if (!typedAttr) {
+    auto shape =
+        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+    Type elementType = typeAttr.getValue();
+    RankedTensorType tensorType =
+        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
+    auto tensorTypeAttr = TypeAttr::get(tensorType);
     p << ": ";
-    p.printAttribute(type);
+    p.printAttribute(tensorTypeAttr);
     needsSpace = true; // subsequent attr value needs a space separator
   }
-  if (attr) {
+  if (initialValueAttr) {
     if (needsSpace)
       p << ' ';
     p << "= ";
-    p.printAttribute(attr);
+    p.printAttribute(initialValueAttr);
   }
 }
 
@@ -657,8 +714,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
            << symName << "' has not been declared by 'tosa.variable'";
 
   // Verify type and shape
-  Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
-  if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+  auto variableType = getVariableType(varOp.value());
+  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
+                                 "the input tensor")
           .failed())
     return failure();
 
@@ -1103,6 +1161,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
+// Builds a tosa.variable from a full variable type: its shape is stored in
+// `var_shape` and its element type in `type`. Emits an error and adds no
+// attributes when the type is not a ranked shaped type.
+static void buildVariableOp(OpBuilder &builder, OperationState &result,
+                            StringRef name, Type variableType,
+                            Attribute initialValue) {
+  auto shapedType = dyn_cast<ShapedType>(variableType);
+  if (!shapedType) {
+    (void)emitError(result.location, "variable type must be a shaped type");
+    return;
+  }
+  if (!shapedType.hasRank()) {
+    (void)emitError(result.location, "variable type must be a ranked type");
+    return;
+  }
+
+  auto varShapeAttr =
+      builder.getIndexTensorAttr(convertFromMlirShape(shapedType.getShape()));
+  result.addAttribute("name", builder.getStringAttr(name));
+  result.addAttribute("var_shape", varShapeAttr);
+  result.addAttribute("type", TypeAttr::get(shapedType.getElementType()));
+  // initial_value is an OptionalAttr: attach it only when the caller supplied
+  // one instead of recording a null attribute.
+  if (initialValue)
+    result.addAttribute("initial_value", initialValue);
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Return Type Inference.
 //===----------------------------------------------------------------------===//
@@ -1676,12 +1761,6 @@ LogicalResult tosa::PadOp::verify() {
   return success();
 }
 
-static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return dim == -1 ? ShapedType::kDynamic : dim;
-  }));
-}
-
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 1a896c1464e1c..de08e7e9a4394 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
 
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
-  ::mlir::Attribute attr = op.getInitialValueAttr();
-  if (attr == nullptr)
-    return failure();
-
-  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-    addType(getElementTypeOrSelf(typedAttr));
-    return success();
-  }
-  return failure();
+  addType(op.getType());
+  return success();
 }
 
 template <>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f9db5dcb88b4c..ea862ecb49e4e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -238,10 +238,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  template <typename T>
-  bool levelCheckRank(Operation *op, const T &v,
+  // Perform the Level Rank check on the tensor type.
+  bool levelCheckRank(Operation *op, const Type typeToCheck,
                       const StringRef operandOrResult, int32_t highest_rank) {
-    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+    if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
@@ -255,10 +255,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  // Perform the Level tensor size check on the input tensor.
-  bool levelCheckSize(Operation *op, const Value &v,
+  // Perform the Level Rank check on the tensor value.
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const StringRef operandOrResult, int32_t highest_rank) {
+    return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
+  }
+
+  // Perform the Level tensor size check on the tensor type.
+  bool levelCheckSize(Operation *op, const Type &typeToCheck,
                       const StringRef operandOrResult);
 
+  // Perform the Level tensor size check on the tensor value.
+  bool levelCheckSize(Operation *op, const Value &v,
+                      const StringRef operandOrResult) {
+    return levelCheckSize(op, v.getType(), operandOrResult);
+  }
+
   // Level check sizes of all operands and results of the operation.
   template <typename T>
   bool levelCheckSizes(T tosaOp) {
@@ -284,15 +296,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         return false;
     }
 
-    if (!op->getAttrs().empty()) {
-      for (NamedAttribute attr : op->getAttrs()) {
-        if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
-          if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
-            return false;
-        }
-      }
-    }
-
     for (auto v : op->getResults()) {
       if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
         return false;
@@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   return true;
 }
 
+template <>
+bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
+    return false;
+
+  return true;
+}
+
+template <>
+bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckSize(op, variableType, "variable type"))
+    return false;
+
+  return true;
+}
+
 bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 #define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
@@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   return true;
 }
 
-// Perform the Level tensor size check
-bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
+// Perform the Level tensor size check on the tensor type.
+bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
                                     const StringRef operandOrResult) {
-  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
     if (!type.hasRank()) {
       op->emitOpError() << "failed level check: unranked tensor";
       return false;
@@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type,
 }
 
 bool TosaValidation::CheckVariable(Operation *op) {
-  if (isa<mlir::tosa::VariableOp>(op)) {
-    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
+  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
+    mlir::StringAttr nameAttr = variableOp.getNameAttr();
 
     if (variablesMap.count(nameAttr)) {
       op->emitOpError() << "name has already been declared";
       return false;
     }
 
-    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
-    mlir::Type type = typeAttr.getValue();
+    // Use getVariableType() so unknown dims, stored in var_shape as -1, become
+    // ShapedType::kDynamic. Passing the raw var_shape values straight to
+    // RankedTensorType::get would build an invalid type (negative dim size),
+    // e.g. for `tosa.variable @v : tensor<2x?x8xi32>`.
+    RankedTensorType variableType = getVariableType(variableOp);
 
-    variablesMap[nameAttr] = type;
+    variablesMap[nameAttr] = variableType;
   }
 
   return true;
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
index 365b05ff084da..d2092753f1f58 100644
--- a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
+// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s
 
 module {
   // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
@@ -10,4 +10,18 @@ module {
     %0 = tosa.variable_read @var_x : tensor<1xf32>
     return %0 : tensor<1xf32>
   }
+}
+
+// -----
+
+module {
+  // CHECK: ml_program.global private mutable @var_x : tensor<f32>
+  tosa.variable @var_x : tensor<f32>
+  func.func @test_stateful_ops(%arg0: tensor<f32>) -> (tensor<f32>) {
+    // CHECK: ml_program.global_store @var_x = %arg0 : tensor<f32>
+    tosa.variable_write @var_x, %arg0 : tensor<f32>
+    // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<f32>
+    %0 = tosa.variable_read @var_x : tensor<f32>
+    return %0 : tensor<f32>
+  }
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c41f079ec526c..05505c3671674 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
+func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+  tosa.variable @stored_var : tensor<*xi8>
+  // expected-error at +1 {{custom op 'tosa.variable' expected ranked type}}
+  return
+}
+
+// -----
+
+func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+  // expected-error at +1 {{elements literal type must have static shape}}
+  tosa.variable @stored_var = dense<0> : tensor<*xi8>
+  // expected-error at +1 {{custom op 'tosa.variable' expected attribute}}
+  return
+}
+
+// -----
+
 func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7d0a0e1fa4ea..223bf3b635e18 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
 
 // -----
 func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
-  // expected-error at +1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}}
+  // expected-error at +1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.const"() {values = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
   return %0: tensor<1x1x1x1x1x1x1xi32>
 }
@@ -1089,7 +1089,8 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
 // -----
 
 func.func @test_variable_read_write_tensor_size_invalid() -> () {
-  tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
+  // expected-error at +1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  tosa.variable @stored_var : tensor<536870912xf32>
   // expected-error at +1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
   %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
   // expected-error at +1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
@@ -1156,8 +1157,8 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
 // -----
 
 func.func @test_variable_read_write_rank_invalid() -> () {
-  // expected-error at +1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
-  tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
+  // expected-error at +1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
+  tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
   // expected-error at +1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
   // expected-error at +1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 25f63331f39df..9953eb375d3ac 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -31,3 +31,48 @@ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
   tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
   return
 }
+
+// -----
+// CHECK-LABEL:   @test_variable_scalar_no_initial_value(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+  // CHECK:           tosa.variable @stored_var : tensor<f32>
+  tosa.variable @stored_var : tensor<f32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+  %0 = tosa.variable_read @stored_var : tensor<f32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+  tosa.variable_write @stored_var, %1 : tensor<f32>
+  return
+}
+
+// -----
+// CHECK-LABEL:   @test_variable_tensor_no_initial_value(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK:           tosa.variable @stored_var : tensor<2x4x8xi32>
+  tosa.variable @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+// CHECK-LABEL:   @test_variable_tensor_with_unknowns(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK:           tosa.variable @stored_var : tensor<2x?x8xi32>
+  tosa.variable @stored_var : tensor<2x?x8xi32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}



More information about the Mlir-commits mailing list