[Mlir-commits] [mlir] [TOSA] Add StatefulOps to TOSA Dialect (PR #66843)

Tai Ly llvmlistbot at llvm.org
Fri Sep 22 09:12:41 PDT 2023


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/66843

>From 98c12b744232f45f1b14b233ba305b34553821e3 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Sun, 18 Sep 2022 19:38:14 -0700
Subject: [PATCH 1/2] Add StatefulOps to TOSA Dialect

This patch adds tosa.variable, tosa.variable.read and
tosa.variable.write operators and tests.

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  5 +
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       | 67 +++++++++++++
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  2 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  2 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 43 ++++++++
 .../Tosa/Transforms/TosaValidation.cpp        | 98 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/invalid.mlir           | 45 +++++++++
 mlir/test/Dialect/Tosa/variables.mlir         | 33 +++++++
 8 files changed, 290 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Tosa/variables.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 555d9bea18ba4dc..a9bc3351f4cff05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,13 @@ class PatternRewriter;
 
 namespace tosa {
 
+/// Parser and printer hooks for the `custom<TypeOrAttr>` assembly directive
+/// used by `tosa.variable` (see TosaUtilOps.td).
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                            Attribute &attr);
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                     Attribute attr);
+
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d75f5dffa8716c9..9731ae210d5a7d2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
   let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+def Tosa_VariableOp : Tosa_Op<"variable", []> {
+  let summary = "Defines a variable";
+
+  let description = [{
+    Defines a new TOSA variable. This is a mutable value.
+    Modifications are expressed using read/write semantics.
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$name,
+    TypeAttr:$type,
+    OptionalAttr<AnyAttr>:$initial_value
+  );
+
+  let assemblyFormat = [{
+    $name
+    attr-dict
+    custom<TypeOrAttr>($type, $initial_value)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.write
+//===----------------------------------------------------------------------===//
+def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+  let summary = "Writes a value to a variable";
+
+  let description = [{
+    Assigns a value to the pseudo-buffer resource holding a mutable tensor.
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$name,
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `,` $value `:` type($value)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.read
+//===----------------------------------------------------------------------===//
+def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+  let summary = "Reads the value of a variable";
+
+  let description = [{
+    Reads the value from a pseudo-buffer resource holding a mutable tensor.
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$name
+  );
+
+  let results = (outs
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `:` type($value)
+  }];
+}
+
 #endif // TOSA_UTIL_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 6b5dd9c970703ee..409a3311c002c10 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -65,7 +65,7 @@ struct ValidationOptions {
   }
 };
 
-std::unique_ptr<Pass> createTosaValidationPass(
+std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
     ValidationOptions const &options = ValidationOptions());
 
 #define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..0c313640052b494 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -82,7 +82,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
   let cppNamespace = "mlir::tosa";
 }
 
-def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
   let summary = "Validates TOSA dialect";
   let description = [{
     This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 616aad8c4aaf08f..f731049cb7fff36 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -97,6 +97,68 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Parsers and printers
+//===----------------------------------------------------------------------===//
+
+/// Parses the `custom<TypeOrAttr>` suffix of `tosa.variable`, which is one of:
+///   `= <typed-attr>`     declared type taken from the initial value
+///   `: <type>`           no initial value
+///   `: <type> = <attr>`  explicit type plus an initial value
+/// printTypeOrAttr emits the last form whenever the initial value is untyped
+/// or its type differs from the declared type.
+ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                                        Attribute &attr) {
+  if (succeeded(parser.parseOptionalEqual())) {
+    auto attrLoc = parser.getCurrentLocation();
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+    // `type` is a required attribute: it cannot be derived from an untyped
+    // value, so reject rather than leave it unset.
+    auto typedAttr = dyn_cast<TypedAttr>(attr);
+    if (!typedAttr)
+      return parser.emitError(attrLoc) << "expected typed attribute";
+    typeAttr = TypeAttr::get(typedAttr.getType());
+    return success();
+  }
+
+  Type type;
+  if (failed(parser.parseColonType(type))) {
+    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  }
+  typeAttr = TypeAttr::get(type);
+
+  // An initial value may follow an explicit type.
+  if (succeeded(parser.parseOptionalEqual()) &&
+      failed(parser.parseAttribute(attr))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected attribute";
+  }
+  return success();
+}
+
+/// Prints the `custom<TypeOrAttr>` suffix of `tosa.variable`. The `: <type>`
+/// part is omitted only when the initial value is a typed attribute whose
+/// type equals the declared type.
+void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                                 Attribute attr) {
+  bool needsSpace = false;
+  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
+  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+    p << ": ";
+    p.printAttribute(type);
+    needsSpace = true; // subsequent attr value needs a space separator
+  }
+  if (attr) {
+    if (needsSpace)
+      p << ' ';
+    p << "= ";
+    p.printAttribute(attr);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..6fea37364f49ecb 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,6 +14,9 @@
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
 
+#include <string>
+#include <unordered_map>
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Builders.h"
@@ -101,7 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
     this->level = options.level;
   }
-  void runOnOperation() override;
+  void runOnOperation() final;
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
     for (auto &checker : const_checkers) {
@@ -113,6 +116,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   LogicalResult applyLevelCheck(Operation *op);
 
+  // check variable read/write data types against variable declarations
+  LogicalResult applyVariableCheck(Operation *op);
+
 private:
   void populateConstantOperandChecks() {
     const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +404,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }
 
+  bool CheckVariable(Operation *op);
+  bool CheckVariableReadOrWrite(Operation *op);
+
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   tosa_level_t tosa_level;
+  std::unordered_map<std::string, mlir::Type> variables_map;
 };
 
 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +437,83 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   return success();
 }
 
+inline bool CompatibleTypes(const mlir::Type &type,
+                            const mlir::Type &declared_type) {
+  // for now, simply use type equality comparison
+  return type == declared_type;
+}
+
+bool TosaValidation::CheckVariable(Operation *op) {
+  if (isa<mlir::tosa::VariableOp>(op)) {
+    auto name_attr = dyn_cast<mlir::StringAttr>(op->getAttr("name"));
+    if (!name_attr) {
+      op->emitOpError() << "Name attribute is not StringAttr";
+      return false;
+    }
+    std::string name = name_attr.getValue().str();
+
+    if (variables_map.count(name)) {
+      op->emitOpError() << "name has already been declared";
+      return false;
+    }
+
+    auto type_attr = dyn_cast<mlir::TypeAttr>(op->getAttr("type"));
+    if (!type_attr) {
+      op->emitOpError() << "type attribute is not TypeAttr";
+      return false;
+    }
+    mlir::Type type = type_attr.getValue();
+
+    variables_map[name] = type;
+  }
+
+  return true;
+}
+// Checks that variable.read/write name a declared variable of the same type.
+bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+  if (isa<mlir::tosa::VariableReadOp>(op) ||
+      isa<mlir::tosa::VariableWriteOp>(op)) {
+    auto name_attr = dyn_cast<mlir::FlatSymbolRefAttr>(op->getAttr("name"));
+    if (!name_attr) {
+      op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
+      return false;
+    }
+    std::string name = name_attr.getValue().str();
+
+    if (!variables_map.count(name)) {
+      op->emitOpError() << "name has not been declared";
+      return false;
+    }
+
+    auto var_type = variables_map[name];
+
+    for (auto v : op->getOperands()) {
+      auto type = v.getType();
+      if (!CompatibleTypes(type, var_type)) {
+        op->emitOpError() << "operand type does not equal variable type";
+        return false;
+      }
+    }
+
+    for (auto v : op->getResults()) {
+      auto type = v.getType();
+      if (!CompatibleTypes(type, var_type)) {
+        op->emitOpError() << "result type does not equal variable type";
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+// Runs the declaration check, then (only if it passed) the read/write check.
+LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
+  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
+    return failure();
+  }
+  return success();
+}
+// Walks nested operations and runs the validation checks; failures fail the pass.
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
@@ -440,18 +527,23 @@ void TosaValidation::runOnOperation() {
       }
     }
 
-    // Some uses of TOSA rely on the constant operands of particular operations.
+    // Some uses of TOSA rely on the constant operands of particular
+    // operations.
     if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
       signalPassFailure();
 
     // do level checks
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
+
+    // do variable type checks
+    if (failed(applyVariableCheck(op)))
+      signalPassFailure();
   });
 }
 } // namespace
 
-std::unique_ptr<Pass>
+std::unique_ptr<OperationPass<ModuleOp>>
 mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
   return std::make_unique<TosaValidation>(options);
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c58bb10b9c5ed6..9233662e88db902 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -203,3 +203,64 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
+
+// -----
+
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable' op name has already been declared}}
+  tosa.variable @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_read_undeclared() -> () {
+  // expected-error at +1 {{'tosa.variable.read' op name has not been declared}}
+  %0 = tosa.variable.read @undeclared_var : tensor<2x4x8xi32>
+  return
+}
+
+// -----
+
+func.func @test_variable_write_undeclared(%arg0: tensor<2x4x8xi32>) -> () {
+  // expected-error at +1 {{'tosa.variable.write' op name has not been declared}}
+  tosa.variable.write @undeclared_var, %arg0 : tensor<2x4x8xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
new file mode 100644
index 000000000000000..9a26aa0bc8bf4d5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// Round-trip tests for tosa.variable, tosa.variable.read and tosa.variable.write.
+// The split markers below are inert: neither RUN line uses -split-input-file.
+// -----
+// CHECK-LABEL:   @test_variable_scalar(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+  // CHECK:           tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+  tosa.variable @stored_var = dense<3.14> : tensor<f32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
+  %0 = tosa.variable.read @stored_var : tensor<f32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+  tosa.variable.write @stored_var, %1 : tensor<f32>
+  return
+}
+
+// -----
+// CHECK-LABEL:   @test_variable_tensor(
+// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+  // CHECK:           tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+  // CHECK:           tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+  tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+  return
+}

>From 645f2bab965b60fff66b6af97f947b265237a7a9 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Sun, 18 Sep 2022 19:38:14 -0700
Subject: [PATCH 2/2] Add StatefulOps to TOSA Dialect

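Follow-up to the patch adding tosa.variable, tosa.variable.read and
tosa.variable.write: the read/write ops now take a SymbolNameAttr, the
validation pass tracks declared variables in a DenseMap, and the hand-written
createTosaValidationPass() is replaced by the tablegen-generated
createTosaValidation().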

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  4 +-
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       |  4 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  3 --
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  1 -
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  5 +--
 .../Tosa/Transforms/TosaValidation.cpp        | 40 +++++--------------
 6 files changed, 17 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 818d43ffe4e572e..0cb389d87859b74 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -35,8 +35,9 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 void addTosaToLinalgPasses(
     OpPassManager &pm, bool disableTosaDecompositions = false,
     // Note: Default to 'none' level unless otherwise specified.
-    tosa::ValidationOptions const &validationOptions =
-        tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
+    tosa::TosaValidationOptions const &validationOptions = {
+        /*profile=*/tosa::TosaProfileEnum::Undefined,
+        /*StrictOperationSpecAlignment=*/false,
+        /*level=*/tosa::TosaLevelEnum::None});
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 9731ae210d5a7d2..f9f25da1b649dea 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -114,7 +114,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
   }];
 
   let arguments = (ins
-    FlatSymbolRefAttr:$name,
+    SymbolNameAttr:$name, // name of the tosa.variable being written
     AnyType:$value
   );
 
@@ -134,7 +134,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
   }];
 
   let arguments = (ins
-    FlatSymbolRefAttr:$name
+    SymbolNameAttr:$name // name of the tosa.variable being read
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 409a3311c002c10..955d6182e81af3e 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -65,9 +65,6 @@ struct ValidationOptions {
   }
 };
 
-std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
-    ValidationOptions const &options = ValidationOptions());
-
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 0c313640052b494..736795951b34326 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -88,7 +88,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
     This pass validates if input TOSA operations match the specification for given
     criteria, e.g. TOSA profile.
   }];
-  let constructor = "createTosaValidationPass()";
 
   let options = [
       Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index d7e867d92282395..0a9eaf8882d713d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, bool disableTosaDecompositions,
-    tosa::ValidationOptions const &validationOptions) {
+    tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -89,7 +89,8 @@ void mlir::tosa::addTosaToLinalgPasses(
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(
-      tosa::createTosaValidationPass(validationOptions));
+  // TosaValidation is anchored on builtin.module, so it must be added at the
+  // module level rather than nested under func.func.
+  pm.addPass(tosa::createTosaValidation(validationOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6fea37364f49ecb..d686ce125c13516 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -99,9 +99,10 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
   explicit TosaValidation() { populateConstantOperandChecks(); }
-  explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+  explicit TosaValidation(const TosaValidationOptions &options)
+      : TosaValidation() {
     this->profile = options.profile;
-    this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
     this->level = options.level;
   }
   void runOnOperation() final;
@@ -409,7 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   tosa_level_t tosa_level;
-  std::unordered_map<std::string, mlir::Type> variables_map;
+  // Declared type of each tosa.variable, keyed by its uniqued name attribute.
+  DenseMap<mlir::StringAttr, mlir::Type> variables_map;
 };
 
 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -445,26 +447,17 @@ inline bool CompatibleTypes(const mlir::Type &type,
 
 bool TosaValidation::CheckVariable(Operation *op) {
   if (isa<mlir::tosa::VariableOp>(op)) {
-    auto name_attr = dyn_cast<mlir::StringAttr>(op->getAttr("name"));
-    if (!name_attr) {
-      op->emitOpError() << "Name attribute is not StringAttr";
-      return false;
-    }
-    std::string name = name_attr.getValue().str();
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
 
-    if (variables_map.count(name)) {
+    if (variables_map.count(name_attr)) {
       op->emitOpError() << "name has already been declared";
       return false;
     }
 
-    auto type_attr = dyn_cast<mlir::TypeAttr>(op->getAttr("type"));
-    if (!type_attr) {
-      op->emitOpError() << "type attribute is not TypeAttr";
-      return false;
-    }
+    auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
     mlir::Type type = type_attr.getValue();
 
-    variables_map[name] = type;
+    variables_map[name_attr] = type;
   }
 
   return true;
@@ -473,19 +466,14 @@ bool TosaValidation::CheckVariable(Operation *op) {
 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
   if (isa<mlir::tosa::VariableReadOp>(op) ||
       isa<mlir::tosa::VariableWriteOp>(op)) {
-    auto name_attr = dyn_cast<mlir::FlatSymbolRefAttr>(op->getAttr("name"));
-    if (!name_attr) {
-      op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
-      return false;
-    }
-    std::string name = name_attr.getValue().str();
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
 
-    if (!variables_map.count(name)) {
+    if (!variables_map.count(name_attr)) {
       op->emitOpError() << "name has not been declared";
       return false;
     }
 
-    auto var_type = variables_map[name];
+    auto var_type = variables_map.lookup(name_attr);
 
     for (auto v : op->getOperands()) {
       auto type = v.getType();
@@ -517,2 +505,3 @@
 void TosaValidation::runOnOperation() {
+  variables_map.clear(); // pass instances may be reused across runs
   configLevelAndProfile();
@@ -542,8 +531,3 @@ void TosaValidation::runOnOperation() {
   });
 }
 } // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
-  return std::make_unique<TosaValidation>(options);
-}



More information about the Mlir-commits mailing list