[Mlir-commits] [mlir] af972f0 - [TOSA] Add StatefulOps to TOSA Dialect (#66843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 16 16:10:21 PDT 2023


Author: Tai Ly
Date: 2023-10-16T16:10:17-07:00
New Revision: af972f01c01843a9ffe41ff496154267fa387a51

URL: https://github.com/llvm/llvm-project/commit/af972f01c01843a9ffe41ff496154267fa387a51
DIFF: https://github.com/llvm/llvm-project/commit/af972f01c01843a9ffe41ff496154267fa387a51.diff

LOG: [TOSA] Add StatefulOps to TOSA Dialect (#66843)

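This patch adds the tosa.variable, tosa.variable.read and
tosa.variable.write operators, with tests. The tosa-validate pass now
runs on the module and checks each variable read/write against the type
of its tosa.variable declaration.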


Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355

---------

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Co-authored-by: Jerry Ge <jerry.ge at arm.com>

Added: 
    mlir/test/Dialect/Tosa/variables.mlir

Modified: 
    mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index d8d4027500f99c6..c411010603ac61f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 void addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
     // Note: Default to 'none' level unless otherwise specified.
-    tosa::ValidationOptions const &validationOptions =
-        tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
+    tosa::TosaValidationOptions const &validationOptions = {
+        tosa::TosaProfileEnum::Undefined, /*StrictOperationSpecAlignment=*/false, tosa::TosaLevelEnum::None});
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 555d9bea18ba4dc..a9bc3351f4cff05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,13 @@ class PatternRewriter;

 namespace tosa {

+/// Parse/print hooks for the `custom<TypeOrAttr>($type, $initial_value)`
+/// assembly directive used by tosa.variable.
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                            Attribute &attr);
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                     Attribute attr);
+
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"

 } // namespace tosa

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d75f5dffa8716c9..f9f25da1b649dea 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
   let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
 }

+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+def Tosa_VariableOp : Tosa_Op<"variable", []> {
+  let summary = "Defines a variable";
+
+  let description = [{
+    Defines a new TOSA variable. This is a mutable value.
+    Modifications are expressed with tosa.variable.read/tosa.variable.write.
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$name,
+    TypeAttr:$type,
+    OptionalAttr<AnyAttr>:$initial_value
+  );
+
+  let assemblyFormat = [{
+    $name
+    attr-dict
+    custom<TypeOrAttr>($type, $initial_value)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.write
+//===----------------------------------------------------------------------===//
+def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+  let summary = "Writes a value to a variable";
+
+  let description = [{
+    Assigns a value to the pseudo-buffer resource holding a mutable tensor.
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$name,
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `,` $value `:` type($value)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.read
+//===----------------------------------------------------------------------===//
+def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+  let summary = "Reads the value of a variable";
+
+  let description = [{
+    Reads the value from the pseudo-buffer resource holding a mutable tensor.
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$name
+  );
+
+  let results = (outs
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `:` type($value)
+  }];
+}
+
 #endif // TOSA_UTIL_OPS

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 940aed107e2f916..fbfc56dfe2cf4f1 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -68,9 +68,6 @@ struct ValidationOptions {
   }
 };

-std::unique_ptr<Pass> createTosaValidationPass(
-    ValidationOptions const &options = ValidationOptions());
-
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"


diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index ac100a6d75c7c08..a0f670de20150fb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,13 +89,12 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
   let cppNamespace = "mlir::tosa";
 }

-def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
   let summary = "Validates TOSA dialect";
   let description = [{
     This pass validates if input TOSA operations match the specification for given
     criteria, e.g. TOSA profile.
   }];
-  let constructor = "createTosaValidationPass()";

   let options = [
       Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 718e34ced8d7e70..3c54f85b033b0b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {

 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
-    tosa::ValidationOptions const &validationOptions) {
+    tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -90,7 +90,7 @@ void mlir::tosa::addTosaToLinalgPasses(
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
       {options.aggressiveReduceConstant}));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(
-      tosa::createTosaValidationPass(validationOptions));
+  // tosa-validate is anchored on ModuleOp, so it cannot be nested under func.
+  pm.addPass(tosa::createTosaValidation(validationOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6db04fe38bcd356..ff34183f9a030a8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -146,6 +146,72 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// Parsers and printers
+//===----------------------------------------------------------------------===//
+
+/// Parses the `custom<TypeOrAttr>($type, $initial_value)` directive used by
+/// tosa.variable. Accepted forms:
+///   `= <typed-attr>`     : `type` is taken from the attribute's own type;
+///   `: <type>`           : no initial value;
+///   `: <type> = <attr>`  : explicit type plus an initial value, as emitted by
+///                          printTypeOrAttr when the attribute is untyped or
+///                          its type differs from `type`.
+ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                                        Attribute &attr) {
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+    auto typedAttr = dyn_cast<TypedAttr>(attr);
+    if (!typedAttr) {
+      // Without a type on the attribute, `type` cannot be derived.
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected typed attribute or explicit type";
+    }
+    typeAttr = TypeAttr::get(typedAttr.getType());
+    return success();
+  }
+
+  Type type;
+  if (failed(parser.parseColonType(type))) {
+    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  }
+  typeAttr = TypeAttr::get(type);
+
+  // printTypeOrAttr emits `: type = attr` when the initial value's type does
+  // not match `type`; accept that form so printed IR parses back.
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+  }
+
+  return success();
+}
+
+/// Prints the `custom<TypeOrAttr>` directive. `: type` is omitted only when
+/// the initial value is a typed attribute of exactly that type; the initial
+/// value, if present, is printed as `= attr`.
+void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                                 Attribute attr) {
+  bool needsSpace = false;
+  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
+  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+    p << ": ";
+    p.printAttribute(type);
+    needsSpace = true; // subsequent attr value needs a space separator
+  }
+  if (attr) {
+    if (needsSpace)
+      p << ' ';
+    p << "= ";
+    p.printAttribute(attr);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..d686ce125c13516 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,6 +14,9 @@
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

+#include <string>
+#include <unordered_map>
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Builders.h"
@@ -96,12 +99,13 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
   explicit TosaValidation() { populateConstantOperandChecks(); }
-  explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+  explicit TosaValidation(const TosaValidationOptions &options)
+      : TosaValidation() {
     this->profile = options.profile;
-    this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
     this->level = options.level;
   }
-  void runOnOperation() override;
+  void runOnOperation() final;

   LogicalResult applyConstantOperandCheck(Operation *op) {
     for (auto &checker : const_checkers) {
@@ -113,6 +117,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

   LogicalResult applyLevelCheck(Operation *op);

+  // check variable read/write data types against variable declarations
+  LogicalResult applyVariableCheck(Operation *op);
+
 private:
   void populateConstantOperandChecks() {
     const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +405,15 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }

+  // Records a tosa.variable declaration; returns false on a duplicate name.
+  bool CheckVariable(Operation *op);
+  // Checks a tosa.variable.read/write against the declared variable type.
+  bool CheckVariableReadOrWrite(Operation *op);
+
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   tosa_level_t tosa_level;
+  // Declared type of each tosa.variable, keyed by its name attribute.
+  DenseMap<mlir::StringAttr, mlir::Type> variables_map;
 };

 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +441,81 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   return success();
 }

+// Returns true if a variable access of type `type` is compatible with the
+// type `declared_type` that the variable was declared with.
+inline bool CompatibleTypes(const mlir::Type &type,
+                            const mlir::Type &declared_type) {
+  // for now, simply use type equality comparison
+  return type == declared_type;
+}
+
+// If `op` is a tosa.variable, records its declared type under its name.
+// Emits an error and returns false if that name was already declared.
+bool TosaValidation::CheckVariable(Operation *op) {
+  if (isa<mlir::tosa::VariableOp>(op)) {
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
+
+    // StringAttr is uniqued per context, so the attribute value itself (not
+    // the address of this local) identifies the variable.
+    if (variables_map.count(name_attr)) {
+      op->emitOpError() << "name has already been declared";
+      return false;
+    }
+
+    auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
+    mlir::Type type = type_attr.getValue();
+
+    variables_map[name_attr] = type;
+  }
+
+  return true;
+}
+
+// If `op` is a tosa.variable.read or tosa.variable.write, checks that its
+// variable has been declared and that every operand and result type matches
+// the declared type. Emits an error and returns false otherwise.
+bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+  if (isa<mlir::tosa::VariableReadOp, mlir::tosa::VariableWriteOp>(op)) {
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
+
+    auto it = variables_map.find(name_attr);
+    if (it == variables_map.end()) {
+      op->emitOpError() << "name has not been declared";
+      return false;
+    }
+
+    mlir::Type var_type = it->second;
+
+    for (auto v : op->getOperands()) {
+      auto type = v.getType();
+      if (!CompatibleTypes(type, var_type)) {
+        op->emitOpError() << "operand type does not equal variable type";
+        return false;
+      }
+    }
+
+    for (auto v : op->getResults()) {
+      auto type = v.getType();
+      if (!CompatibleTypes(type, var_type)) {
+        op->emitOpError() << "result type does not equal variable type";
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Runs both variable checks on `op`; fails if either one reports an error.
+LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
+  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
+    return failure();
+  }
+  return success();
+}
+
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
+  // Declarations recorded by a previous run must not leak into this one.
+  variables_map.clear();
   getOperation().walk([&](Operation *op) {
@@ -440,18 +529,18 @@ void TosaValidation::runOnOperation() {
       }
     }

-    // Some uses of TOSA rely on the constant operands of particular operations.
+    // Some uses of TOSA rely on the constant operands of particular
+    // operations.
     if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
       signalPassFailure();

     // do level checks
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
+
+    // do variable type checks
+    if (failed(applyVariableCheck(op)))
+      signalPassFailure();
   });
 }
 } // namespace
-
-std::unique_ptr<Pass>
-mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
-  return std::make_unique<TosaValidation>(options);
-}

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c58bb10b9c5ed6..9233662e88db902 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
+
+// -----
+
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable' op name has already been declared}}
+  tosa.variable @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}

diff  --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
new file mode 100644
index 000000000000000..9a26aa0bc8bf4d5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// Round-trip tests for the custom assembly of tosa.variable,
+// tosa.variable.read and tosa.variable.write (also via the generic form).
+// -----
+// CHECK-LABEL:   @test_variable_scalar(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+  // CHECK:           tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+  tosa.variable @stored_var = dense<3.14> : tensor<f32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
+  %0 = tosa.variable.read @stored_var : tensor<f32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+  tosa.variable.write @stored_var, %1 : tensor<f32>
+  return
+}
+
+// -----
+// CHECK-LABEL:   @test_variable_tensor(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK:           tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}


        


More information about the Mlir-commits mailing list