[Mlir-commits] [mlir] [mlir][tosa] Align Variable ops to match with TOSA v1.0 spec (PR #130680)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 13:52:19 PDT 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/130680

>From 5cd5f7c914fd41ee437c0df5be2bc5c243d02330 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Wed, 4 Dec 2024 00:29:10 +0000
Subject: [PATCH] [mlir][tosa] Align Variable ops to match with TOSA v1.0 spec

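* renamed tosa.variable.write / tosa.variable.read to tosa.variable_write /
  tosa.variable_read to match the TOSA v1.0 spec
* changed AnyType:$value to Tosa_Tensor:$input1 for VariableWrite and to
  Tosa_Tensor:$output1 for VariableRead
* updated description discrepancies
* note: the TOSA spec defines a var_shape attribute, but in MLIR the shape
  is already carried by TypeAttr:$type, so no separate attribute is added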

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I4cd0348cd4e306dbc2e0e53a89a9404d91fb44d4
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td | 16 ++++++++--------
 .../TosaToMLProgram/TosaToMLProgram.cpp          |  2 +-
 .../Tosa/Transforms/TosaProfileCompliance.cpp    |  8 +++++++-
 .../Dialect/Tosa/Transforms/TosaValidation.cpp   |  5 ++---
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir    |  4 ++--
 .../TosaToMLProgram/tosa-to-mlprogram.mlir       |  4 ++--
 mlir/test/Dialect/Tosa/invalid.mlir              | 16 ++++++++--------
 mlir/test/Dialect/Tosa/invalid_extension.mlir    | 10 +++++-----
 mlir/test/Dialect/Tosa/level_check.mlir          | 16 ++++++++--------
 mlir/test/Dialect/Tosa/variables.mlir            | 16 ++++++++--------
 10 files changed, 51 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 3b2ede1b1a1a2..0ab0a62f1cf11 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -109,9 +109,9 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
 }
 
 //===----------------------------------------------------------------------===//
-// Operator: variable.write
+// Operator: variable_write
 //===----------------------------------------------------------------------===//
-def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
   let summary = "write_buffer operator";
 
   let description = [{
@@ -120,7 +120,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
 
   let arguments = (ins
     SymbolNameAttr:$name,
-    AnyType:$value
+    Tosa_Tensor:$input1
   );
 
   list<Availability> availability = [
@@ -129,14 +129,14 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
   ];
 
   let assemblyFormat = [{
-    $name attr-dict `,` $value `:` type($value)
+    $name attr-dict `,` $input1 `:` type($input1)
   }];
 }
 
 //===----------------------------------------------------------------------===//
-// Operator: variable.read
+// Operator: variable_read
 //===----------------------------------------------------------------------===//
-def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
   let summary = "read_buffer operator";
 
   let description = [{
@@ -148,7 +148,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
   );
 
   let results = (outs
-    AnyType:$value
+    Tosa_Tensor:$output1
   );
 
   list<Availability> availability = [
@@ -157,7 +157,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
   ];
 
   let assemblyFormat = [{
-    $name attr-dict `:` type($value)
+    $name attr-dict `:` type($output1)
   }];
 }
 
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index d134d8cdf485e..310566e692202 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -45,7 +45,7 @@ class VariableWriteOpConverter
     auto globalSymbolRef =
         SymbolRefAttr::get(rewriter.getContext(), op.getName());
     auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
-        op.getLoc(), globalSymbolRef, op.getValue());
+        op.getLoc(), globalSymbolRef, op.getInput1());
     rewriter.replaceOp(op, newVariableWrite);
     return success();
   }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index eb7981b313d1d..f8fd6d2365a05 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -226,6 +226,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
   return failure();
 }
 
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
+  addValue(op.getInput1());
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
   addValue(op.getCondition());
@@ -280,6 +286,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Rescale)
   POPULATE_PROFILE_INFO_CUSTOM(MatMul)
   POPULATE_PROFILE_INFO_CUSTOM(Variable)
+  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
   POPULATE_PROFILE_INFO_CUSTOM(If)
   POPULATE_PROFILE_INFO_CUSTOM(While)
 
@@ -334,7 +341,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_COMMON(Reverse)
   POPULATE_PROFILE_INFO_COMMON(Identity)
   POPULATE_PROFILE_INFO_COMMON(VariableRead)
-  POPULATE_PROFILE_INFO_COMMON(VariableWrite)
 
   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..e7d7d4bcd3d68 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -767,7 +767,7 @@ inline bool CompatibleTypes(const mlir::Type &type,
 
 bool TosaValidation::CheckVariable(Operation *op) {
   if (isa<mlir::tosa::VariableOp>(op)) {
-    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
+    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
 
     if (variablesMap.count(nameAttr)) {
       op->emitOpError() << "name has already been declared";
@@ -786,8 +786,7 @@ bool TosaValidation::CheckVariable(Operation *op) {
 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
   if (isa<mlir::tosa::VariableReadOp>(op) ||
       isa<mlir::tosa::VariableWriteOp>(op)) {
-    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
-
+    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
     if (!variablesMap.count(nameAttr)) {
       op->emitOpError() << "name has not been declared";
       return false;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 731e134ed1a07..37ed5cec00a0d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -6,8 +6,8 @@
 // check that -tosa-validate of stateful ops kick in
 func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
-  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
+  // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+  tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
   return
 }
 
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
index 69b6875987daf..365b05ff084da 100644
--- a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -5,9 +5,9 @@ module {
   tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
   func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
     // CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
-    tosa.variable.write @var_x, %arg0 : tensor<1xf32>
+    tosa.variable_write @var_x, %arg0 : tensor<1xf32>
     // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
-    %0 = tosa.variable.read @var_x : tensor<1xf32>
+    %0 = tosa.variable_read @var_x : tensor<1xf32>
     return %0 : tensor<1xf32>
   }
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 12b2379a592c3..f93a10645471c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -699,8 +699,8 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
-  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
   return
 }
 
@@ -708,8 +708,8 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
-  %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+  // expected-error at +1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+  %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
   return
 }
 
@@ -717,8 +717,8 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
 
 func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
-  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  // expected-error at +1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
   return
 }
 
@@ -726,8 +726,8 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
 
 func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
-  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
+  // expected-error at +1 {{'tosa.variable_write' op operand type does not equal variable type}}
+  tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 241e603e91c61..28962f3b9c262 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -313,17 +313,17 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
 func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
   // expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.read' op illegal: requires [variable]}}
-  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  // expected-error at +1 {{'tosa.variable_read' op illegal: requires [variable]}}
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
   return
 }
 
 // -----
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
   // expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error at +1 {{'tosa.variable.write' op illegal: requires [variable]}}
-  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  // expected-error at +1 {{'tosa.variable_write' op illegal: requires [variable]}}
+  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8c3b2e526e444..a35a1665bc24f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1089,10 +1089,10 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
 
 func.func @test_variable_read_write_tensor_size_invalid() -> () {
   tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
-  // expected-error at +1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.variable.read @stored_var : tensor<536870912xf32>
-  // expected-error at +1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
+  // expected-error at +1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
+  // expected-error at +1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
   return
 }
 
@@ -1157,10 +1157,10 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
 func.func @test_variable_read_write_rank_invalid() -> () {
   // expected-error at +1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
   tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error at +1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
-  %0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error at +1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
-  tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
+  // expected-error at +1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
+  %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
+  // expected-error at +1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
+  tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 9a26aa0bc8bf4..6fa6b26155461 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -8,12 +8,12 @@
 func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
   // CHECK:           tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
   tosa.variable @stored_var = dense<3.14> : tensor<f32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
-  %0 = tosa.variable.read @stored_var : tensor<f32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+  %0 = tosa.variable_read @stored_var : tensor<f32>
   // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
   %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
-  tosa.variable.write @stored_var, %1 : tensor<f32>
+  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+  tosa.variable_write @stored_var, %1 : tensor<f32>
   return
 }
 
@@ -23,11 +23,11 @@ func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
 func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
   // CHECK:           tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
-  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
   // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
   %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
-  tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
   return
 }



More information about the Mlir-commits mailing list