[Mlir-commits] [mlir] [TOSA] Add StatefulOps to TOSA Dialect (PR #66843)
Tai Ly
llvmlistbot at llvm.org
Mon Oct 16 13:48:04 PDT 2023
https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/66843
>From 78c8db81e00a20819bd42def141920699ead7b6a Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Sun, 18 Sep 2022 19:38:14 -0700
Subject: [PATCH 1/2] Add StatefulOps to TOSA Dialect
This patch adds tosa.variable, tosa.variable.read and
tosa.variable.write operators and tests.
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 5 +
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 67 +++++++++++++
.../mlir/Dialect/Tosa/Transforms/Passes.h | 2 +-
.../mlir/Dialect/Tosa/Transforms/Passes.td | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 43 ++++++++
.../Tosa/Transforms/TosaValidation.cpp | 98 ++++++++++++++++++-
mlir/test/Dialect/Tosa/invalid.mlir | 45 +++++++++
mlir/test/Dialect/Tosa/variables.mlir | 33 +++++++
8 files changed, 290 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/variables.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 555d9bea18ba4dc..a9bc3351f4cff05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,11 @@ class PatternRewriter;
namespace tosa {
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr);
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr);
+
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d75f5dffa8716c9..9731ae210d5a7d2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}
+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+def Tosa_VariableOp : Tosa_Op<"variable", []> {
+ let summary = "Defines a variable";
+
+ let description = [{
+ Defines a new TOSA variable. This is a mutable value.
+ Modifications are expressed using read/write semantics.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name,
+ TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ let assemblyFormat = [{
+ $name
+ attr-dict
+ custom<TypeOrAttr>($type, $initial_value)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.write
+//===----------------------------------------------------------------------===//
+def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+ let summary = "write_buffer operator";
+
+ let description = [{
+ Assigns a value to pseudo-buffer resource holding a mutable tensor.
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$name,
+ AnyType:$value
+ );
+
+ let assemblyFormat = [{
+ $name attr-dict `,` $value `:` type($value)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.read
+//===----------------------------------------------------------------------===//
+def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+ let summary = "read_buffer operator";
+
+ let description = [{
+ Reads the value from a pseudo-buffer resource holding a mutable tensor.
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$name
+ );
+
+ let results = (outs
+ AnyType:$value
+ );
+
+ let assemblyFormat = [{
+ $name attr-dict `:` type($value)
+ }];
+}
+
#endif // TOSA_UTIL_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 940aed107e2f916..085eabdbd164522 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -68,7 +68,7 @@ struct ValidationOptions {
}
};
-std::unique_ptr<Pass> createTosaValidationPass(
+std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
ValidationOptions const &options = ValidationOptions());
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index ac100a6d75c7c08..2c4951a4173cf1c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,7 +89,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}
-def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6db04fe38bcd356..ff34183f9a030a8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -146,6 +146,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Parsers and printers
+//===----------------------------------------------------------------------===//
+
+ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr) {
+ if (succeeded(parser.parseOptionalEqual())) {
+ if (failed(parser.parseAttribute(attr))) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected attribute";
+ }
+ if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
+ typeAttr = TypeAttr::get(typedAttr.getType());
+ }
+ return success();
+ }
+
+ Type type;
+ if (failed(parser.parseColonType(type))) {
+ return parser.emitError(parser.getCurrentLocation()) << "expected type";
+ }
+ typeAttr = TypeAttr::get(type);
+
+ return success();
+}
+
+void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr) {
+ bool needsSpace = false;
+ auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
+ if (!typedAttr || typedAttr.getType() != type.getValue()) {
+ p << ": ";
+ p.printAttribute(type);
+ needsSpace = true; // subsequent attr value needs a space separator
+ }
+ if (attr) {
+ if (needsSpace)
+ p << ' ';
+ p << "= ";
+ p.printAttribute(attr);
+ }
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..6fea37364f49ecb 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,6 +14,9 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+#include <string>
+#include <unordered_map>
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
@@ -101,7 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
this->level = options.level;
}
- void runOnOperation() override;
+ void runOnOperation() final;
LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : const_checkers) {
@@ -113,6 +116,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
LogicalResult applyLevelCheck(Operation *op);
+ // check variable read/write data types against variable declarations
+ LogicalResult applyVariableCheck(Operation *op);
+
private:
void populateConstantOperandChecks() {
const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +404,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
+ bool CheckVariable(Operation *op);
+ bool CheckVariableReadOrWrite(Operation *op);
+
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
+ std::unordered_map<std::string, mlir::Type> variables_map;
};
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +437,83 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}
+inline bool CompatibleTypes(const mlir::Type &type,
+ const mlir::Type &declared_type) {
+ // for now, simply use type equality comparison
+ return type == declared_type;
+}
+
+bool TosaValidation::CheckVariable(Operation *op) {
+ if (isa<mlir::tosa::VariableOp>(op)) {
+ auto name_attr = dyn_cast<mlir::StringAttr>(op->getAttr("name"));
+ if (!name_attr) {
+ op->emitOpError() << "Name attribute is not StringAttr";
+ return false;
+ }
+ std::string name = name_attr.getValue().str();
+
+ if (variables_map.count(name)) {
+ op->emitOpError() << "name has already been declared";
+ return false;
+ }
+
+ auto type_attr = dyn_cast<mlir::TypeAttr>(op->getAttr("type"));
+ if (!type_attr) {
+ op->emitOpError() << "type attribute is not TypeAttr";
+ return false;
+ }
+ mlir::Type type = type_attr.getValue();
+
+ variables_map[name] = type;
+ }
+
+ return true;
+}
+
+bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+ if (isa<mlir::tosa::VariableReadOp>(op) ||
+ isa<mlir::tosa::VariableWriteOp>(op)) {
+ auto name_attr = dyn_cast<mlir::FlatSymbolRefAttr>(op->getAttr("name"));
+ if (!name_attr) {
+ op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
+ return false;
+ }
+ std::string name = name_attr.getValue().str();
+
+ if (!variables_map.count(name)) {
+ op->emitOpError() << "name has not been declared";
+ return false;
+ }
+
+ auto var_type = variables_map[name];
+
+ for (auto v : op->getOperands()) {
+ auto type = v.getType();
+ if (!CompatibleTypes(type, var_type)) {
+ op->emitOpError() << "operand type does not equal variable type";
+ return false;
+ }
+ }
+
+ for (auto v : op->getResults()) {
+ auto type = v.getType();
+ if (!CompatibleTypes(type, var_type)) {
+ op->emitOpError() << "result type does not equal variable type";
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
+ if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
+ return failure();
+ }
+ return success();
+}
+
void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
@@ -440,18 +527,23 @@ void TosaValidation::runOnOperation() {
}
}
- // Some uses of TOSA rely on the constant operands of particular operations.
+ // Some uses of TOSA rely on the constant operands of particular
+ // operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();
+
+ // do variable type checks
+ if (failed(applyVariableCheck(op)))
+ signalPassFailure();
});
}
} // namespace
-std::unique_ptr<Pass>
+std::unique_ptr<OperationPass<ModuleOp>>
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
return std::make_unique<TosaValidation>(options);
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c58bb10b9c5ed6..9233662e88db902 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
+
+// -----
+
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable' op name has already been declared}}
+ tosa.variable @stored_var : tensor<1x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+ %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+ %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+ tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+ tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
new file mode 100644
index 000000000000000..9a26aa0bc8bf4d5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// CHECK-LABEL: @test_variable_scalar(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+ tosa.variable @stored_var = dense<3.14> : tensor<f32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
+ %0 = tosa.variable.read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable.write @stored_var, %1 : tensor<f32>
+ return
+}
+
+// -----
+// CHECK-LABEL: @test_variable_tensor(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+}
>From ac7c15be514da0419277cb8baea1427b5896cec0 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Sun, 18 Sep 2022 19:38:14 -0700
Subject: [PATCH 2/2] Add StatefulOps to TOSA Dialect
This patch adds tosa.variable, tosa.variable.read and
tosa.variable.write operators and tests.
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---
.../Conversion/TosaToLinalg/TosaToLinalg.h | 4 +-
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 4 +-
.../mlir/Dialect/Tosa/Transforms/Passes.h | 3 --
.../mlir/Dialect/Tosa/Transforms/Passes.td | 1 -
.../TosaToLinalg/TosaToLinalgPass.cpp | 5 +--
.../Tosa/Transforms/TosaValidation.cpp | 40 +++++--------------
6 files changed, 17 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index d8d4027500f99c6..c411010603ac61f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
// Note: Default to 'none' level unless otherwise specified.
- tosa::ValidationOptions const &validationOptions =
- tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
+ tosa::TosaValidationOptions const &validationOptions = {
+ tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 9731ae210d5a7d2..f9f25da1b649dea 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -114,7 +114,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
}];
let arguments = (ins
- FlatSymbolRefAttr:$name,
+ SymbolNameAttr:$name,
AnyType:$value
);
@@ -134,7 +134,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
}];
let arguments = (ins
- FlatSymbolRefAttr:$name
+ SymbolNameAttr:$name
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 085eabdbd164522..fbfc56dfe2cf4f1 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -68,9 +68,6 @@ struct ValidationOptions {
}
};
-std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
- ValidationOptions const &options = ValidationOptions());
-
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 2c4951a4173cf1c..a0f670de20150fb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -95,7 +95,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
This pass validates if input TOSA operations match the specification for given
criteria, e.g. TOSA profile.
}];
- let constructor = "createTosaValidationPass()";
let options = [
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 718e34ced8d7e70..3c54f85b033b0b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
- tosa::ValidationOptions const &validationOptions) {
+ tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
{options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
- pm.addNestedPass<func::FuncOp>(
- tosa::createTosaValidationPass(validationOptions));
+ pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6fea37364f49ecb..d686ce125c13516 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -99,9 +99,10 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
explicit TosaValidation() { populateConstantOperandChecks(); }
- explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+ explicit TosaValidation(const TosaValidationOptions &options)
+ : TosaValidation() {
this->profile = options.profile;
- this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+ this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
this->level = options.level;
}
void runOnOperation() final;
@@ -409,7 +410,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
- std::unordered_map<std::string, mlir::Type> variables_map;
+ DenseMap<const mlir::StringAttr *, mlir::Type> variables_map;
};
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -445,26 +446,17 @@ inline bool CompatibleTypes(const mlir::Type &type,
bool TosaValidation::CheckVariable(Operation *op) {
if (isa<mlir::tosa::VariableOp>(op)) {
- auto name_attr = dyn_cast<mlir::StringAttr>(op->getAttr("name"));
- if (!name_attr) {
- op->emitOpError() << "Name attribute is not StringAttr";
- return false;
- }
- std::string name = name_attr.getValue().str();
+ auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
- if (variables_map.count(name)) {
+ if (variables_map.count(&name_attr)) {
op->emitOpError() << "name has already been declared";
return false;
}
- auto type_attr = dyn_cast<mlir::TypeAttr>(op->getAttr("type"));
- if (!type_attr) {
- op->emitOpError() << "type attribute is not TypeAttr";
- return false;
- }
+ auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
mlir::Type type = type_attr.getValue();
- variables_map[name] = type;
+ variables_map[&name_attr] = type;
}
return true;
@@ -473,19 +465,14 @@ bool TosaValidation::CheckVariable(Operation *op) {
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
if (isa<mlir::tosa::VariableReadOp>(op) ||
isa<mlir::tosa::VariableWriteOp>(op)) {
- auto name_attr = dyn_cast<mlir::FlatSymbolRefAttr>(op->getAttr("name"));
- if (!name_attr) {
- op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
- return false;
- }
- std::string name = name_attr.getValue().str();
+ auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
- if (!variables_map.count(name)) {
+ if (!variables_map.count(&name_attr)) {
op->emitOpError() << "name has not been declared";
return false;
}
- auto var_type = variables_map[name];
+ auto var_type = variables_map[&name_attr];
for (auto v : op->getOperands()) {
auto type = v.getType();
@@ -542,8 +529,3 @@ void TosaValidation::runOnOperation() {
});
}
} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
- return std::make_unique<TosaValidation>(options);
-}
More information about the Mlir-commits
mailing list