[Mlir-commits] [mlir] [TOSA] Add StatefulOps to TOSA Dialect (PR #66843)
Tai Ly
llvmlistbot at llvm.org
Tue Sep 19 18:22:34 PDT 2023
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/66843
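This patch adds the tosa.variable, tosa.variable.read and tosa.variable.write operators, together with tests. It also extends the TOSA validation pass to check variable reads and writes against their declarations; the pass now runs on mlir::ModuleOp instead of func::FuncOp.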
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
>From 1cefd08297c76d2a57732fd6eb9aa31463c22783 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Sun, 18 Sep 2022 19:38:14 -0700
Subject: [PATCH] Add StatefulOps to TOSA Dialect
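This patch adds the tosa.variable, tosa.variable.read and
tosa.variable.write operators, together with tests. It also extends the
TOSA validation pass to check variable reads and writes against their
declarations; the pass now runs on mlir::ModuleOp instead of func::FuncOp.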
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 5 +
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 67 +++++++++++++
.../mlir/Dialect/Tosa/Transforms/Passes.h | 2 +-
.../mlir/Dialect/Tosa/Transforms/Passes.td | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 43 ++++++++
.../Tosa/Transforms/TosaValidation.cpp | 98 ++++++++++++++++++-
mlir/test/Dialect/Tosa/invalid.mlir | 45 +++++++++
mlir/test/Dialect/Tosa/variables.mlir | 33 +++++++
8 files changed, 290 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/variables.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 555d9bea18ba4dc..a9bc3351f4cff05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,11 @@ class PatternRewriter;
namespace tosa {
+/// Parses the operands of the `custom<TypeOrAttr>` assembly directive used by
+/// `tosa.variable`. Sets `typeAttr` from either an explicit `: <type>` or the
+/// type of a typed initial value, and sets `attr` when an initial value
+/// follows `=`.
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr);
+/// Prints the `custom<TypeOrAttr>` directive. `: <type>` is omitted when
+/// `attr` is a typed attribute whose type equals `type`; `= <attr>` is
+/// printed only when `attr` is non-null.
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr);
+
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d75f5dffa8716c9..9731ae210d5a7d2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}
+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+// Declares a variable named by the symbol `name`. `type` holds the declared
+// variable type and `initial_value` an optional initial value. The
+// custom<TypeOrAttr> directive is implemented by parseTypeOrAttr and
+// printTypeOrAttr (declared in TosaOps.h, defined in TosaOps.cpp).
+// NOTE(review): no op verifier is declared, so nothing here checks that the
+// type of `initial_value` matches `type`.
+def Tosa_VariableOp : Tosa_Op<"variable", []> {
+ let summary = "Defines a variable";
+
+ let description = [{
+ Defines a new TOSA variable. This is a mutable value.
+ Modifications are expressed using read/write semantics.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name,
+ TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ // Produces e.g. `tosa.variable @v = dense<-1> : tensor<2xi32>` or
+ // `tosa.variable @v : tensor<2xi32>`.
+ let assemblyFormat = [{
+ $name
+ attr-dict
+ custom<TypeOrAttr>($type, $initial_value)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.write
+//===----------------------------------------------------------------------===//
+def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+  let summary = "Writes a value to a variable";
+
+  let description = [{
+    Assigns `value` to the pseudo-buffer resource holding the mutable tensor
+    of the variable referenced by the symbol `name`, which names a
+    `tosa.variable` declaration. The TOSA validation pass (`tosa-validate`)
+    reports an error if `name` has not been declared earlier or if the type
+    of `value` differs from the declared variable type.
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$name,
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `,` $value `:` type($value)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.read
+//===----------------------------------------------------------------------===//
+def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+  let summary = "Reads the value of a variable";
+
+  let description = [{
+    Reads the value from the pseudo-buffer resource holding the mutable tensor
+    of the variable referenced by the symbol `name`, which names a
+    `tosa.variable` declaration. The TOSA validation pass (`tosa-validate`)
+    reports an error if `name` has not been declared earlier or if the result
+    type differs from the declared variable type.
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$name
+  );
+
+  let results = (outs
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `:` type($value)
+  }];
+}
+
#endif // TOSA_UTIL_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 72846d5dbe48908..25e26eade88f1cf 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -63,7 +63,7 @@ struct ValidationOptions {
}
};
-std::unique_ptr<Pass> createTosaValidationPass(
+std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
ValidationOptions const &options = ValidationOptions());
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..0c313640052b494 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -82,7 +82,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}
-def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 616aad8c4aaf08f..f731049cb7fff36 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -97,6 +97,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Parsers and printers
+//===----------------------------------------------------------------------===//
+
+/// Parses the `custom<TypeOrAttr>` directive of `tosa.variable`. Accepted
+/// forms:
+///   `= <typed-attr>`     the variable type is the attribute's type
+///   `: <type>`           no initial value
+///   `: <type> = <attr>`  explicit type plus initial value; printTypeOrAttr
+///                        emits this form when `<attr>` is untyped or its
+///                        type differs from `<type>`, so it must parse back
+/// On success `typeAttr` is always set; `attr` is set only when an initial
+/// value is present.
+ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                                        Attribute &attr) {
+  if (succeeded(parser.parseOptionalEqual())) {
+    auto attrLoc = parser.getCurrentLocation();
+    // parseAttribute reports its own diagnostic on failure.
+    if (failed(parser.parseAttribute(attr)))
+      return failure();
+    // `type` is a required attribute, so without an explicit `: <type>` the
+    // initial value must carry one; otherwise `typeAttr` would stay null.
+    auto typedAttr = llvm::dyn_cast<TypedAttr>(attr);
+    if (!typedAttr)
+      return parser.emitError(attrLoc)
+             << "expected a typed attribute when no explicit type is given";
+    typeAttr = TypeAttr::get(typedAttr.getType());
+    return success();
+  }
+
+  Type type;
+  // parseColonType reports its own diagnostic on failure.
+  if (failed(parser.parseColonType(type)))
+    return failure();
+  typeAttr = TypeAttr::get(type);
+
+  // Optional initial value following an explicit type.
+  if (succeeded(parser.parseOptionalEqual()) &&
+      failed(parser.parseAttribute(attr)))
+    return failure();
+
+  return success();
+}
+
+/// Prints the `custom<TypeOrAttr>` directive of `tosa.variable`.
+///
+/// The explicit `: <type>` is omitted when `attr` is a typed attribute whose
+/// type equals `type`, because parseTypeOrAttr recovers the type from such an
+/// attribute. When both parts are printed the output is `: <type> = <attr>`.
+/// `op` is unused here.
+void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                                 Attribute attr) {
+  auto typedAttr = llvm::dyn_cast_or_null<TypedAttr>(attr);
+  const bool printExplicitType =
+      !typedAttr || typedAttr.getType() != type.getValue();
+
+  if (printExplicitType) {
+    p << ": ";
+    p.printType(type.getValue());
+  }
+
+  if (!attr)
+    return;
+
+  // Separate the type from the initial value when both are printed.
+  if (printExplicitType)
+    p << ' ';
+  p << "= ";
+  p.printAttribute(attr);
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..6fea37364f49ecb 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,6 +14,9 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+#include <string>
+#include <unordered_map>
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
@@ -101,7 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
this->level = options.level;
}
- void runOnOperation() override;
+ void runOnOperation() final;
LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : const_checkers) {
@@ -113,6 +116,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
LogicalResult applyLevelCheck(Operation *op);
+ // check variable read/write data types against variable declarations
+ LogicalResult applyVariableCheck(Operation *op);
+
private:
void populateConstantOperandChecks() {
const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +404,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
+ bool CheckVariable(Operation *op);
+ bool CheckVariableReadOrWrite(Operation *op);
+
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
+ std::unordered_map<std::string, mlir::Type> variables_map;
};
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +437,83 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}
+/// Returns true if a value of type `type` may be read from or written to a
+/// variable declared with `declared_type`. For now this is plain type
+/// equality, so element type and shape must both match exactly.
+/// mlir::Type is a lightweight value handle and is passed by value.
+static bool CompatibleTypes(mlir::Type type, mlir::Type declared_type) {
+  return type == declared_type;
+}
+
+/// Records the declaration made by a `tosa.variable` op in `variables_map`.
+/// Any other op is accepted unchanged. Emits an op error and returns false if
+/// the `name`/`type` attributes are missing or mistyped, or if the name was
+/// already recorded earlier in the walk.
+/// NOTE(review): variables_map is not cleared anywhere in the visible code;
+/// confirm runOnOperation resets it, otherwise re-running the same pass
+/// instance would report false "already been declared" errors.
+bool TosaValidation::CheckVariable(Operation *op) {
+  if (!isa<mlir::tosa::VariableOp>(op))
+    return true;
+
+  // getAttrOfType yields null for a missing or mistyped attribute, whereas
+  // dyn_cast on a null attribute would assert instead of reaching the error.
+  auto name_attr = op->getAttrOfType<mlir::StringAttr>("name");
+  if (!name_attr) {
+    op->emitOpError() << "name attribute is not StringAttr";
+    return false;
+  }
+  std::string name = name_attr.getValue().str();
+
+  if (variables_map.count(name)) {
+    op->emitOpError() << "name has already been declared";
+    return false;
+  }
+
+  auto type_attr = op->getAttrOfType<mlir::TypeAttr>("type");
+  if (!type_attr) {
+    op->emitOpError() << "type attribute is not TypeAttr";
+    return false;
+  }
+
+  variables_map.emplace(name, type_attr.getValue());
+  return true;
+}
+
+/// Checks that a `tosa.variable.read` / `tosa.variable.write` op refers to a
+/// variable recorded earlier in the walk by CheckVariable, and that every
+/// operand and result type is compatible with the declared type. Any other op
+/// is accepted unchanged. Emits an op error and returns false on the first
+/// violation.
+bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+  if (!isa<mlir::tosa::VariableReadOp, mlir::tosa::VariableWriteOp>(op))
+    return true;
+
+  // getAttrOfType yields null for a missing or mistyped attribute, whereas
+  // dyn_cast on a null attribute would assert instead of reaching the error.
+  auto name_attr = op->getAttrOfType<mlir::FlatSymbolRefAttr>("name");
+  if (!name_attr) {
+    op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
+    return false;
+  }
+  std::string name = name_attr.getValue().str();
+
+  // One hash lookup instead of count() followed by operator[].
+  auto it = variables_map.find(name);
+  if (it == variables_map.end()) {
+    op->emitOpError() << "name has not been declared";
+    return false;
+  }
+  mlir::Type var_type = it->second;
+
+  for (mlir::Type type : op->getOperandTypes()) {
+    if (!CompatibleTypes(type, var_type)) {
+      op->emitOpError() << "operand type does not equal variable type";
+      return false;
+    }
+  }
+
+  for (mlir::Type type : op->getResultTypes()) {
+    if (!CompatibleTypes(type, var_type)) {
+      op->emitOpError() << "result type does not equal variable type";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Runs the variable declaration check and then the read/write type check on
+/// `op`. The second check is skipped once the first has failed, so at most
+/// one of them emits a diagnostic for a given op.
+LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
+  const bool ok = CheckVariable(op) && CheckVariableReadOrWrite(op);
+  return ok ? success() : failure();
+}
+
void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
@@ -440,18 +527,23 @@ void TosaValidation::runOnOperation() {
}
}
- // Some uses of TOSA rely on the constant operands of particular operations.
+ // Some uses of TOSA rely on the constant operands of particular
+ // operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();
+
+ // do variable type checks
+ if (failed(applyVariableCheck(op)))
+ signalPassFailure();
});
}
} // namespace
-std::unique_ptr<Pass>
+std::unique_ptr<OperationPass<ModuleOp>>
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
return std::make_unique<TosaValidation>(options);
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c58bb10b9c5ed6..9233662e88db902 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
+
+// -----
+
+// Declaring the same variable name twice is reported by the TOSA validation
+// pass (TosaValidation::CheckVariable).
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable' op name has already been declared}}
+  tosa.variable @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+// Reading with a different element type (i16) than declared (i32) is reported
+// by the TOSA validation pass.
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+// Reading with a different shape (1x4x8) than declared (2x4x8) is reported by
+// the TOSA validation pass.
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+  %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+// Writing a value with a different element type (i16) than declared (i32) is
+// reported by the TOSA validation pass.
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+
+// Writing a value with a different shape (1x4x8) than declared (2x4x8) is
+// reported by the TOSA validation pass.
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
new file mode 100644
index 000000000000000..9a26aa0bc8bf4d5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// Scalar (rank-0) variable: declare it with an initial value, read it, add the
+// function argument and write the sum back. The RUN lines print and re-parse
+// this in both custom and generic form, so the expectations below cover
+// round-tripping; 3.14 is printed back as 3.140000e+00.
+// CHECK-LABEL: @test_variable_scalar(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+ tosa.variable @stored_var = dense<3.14> : tensor<f32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
+ %0 = tosa.variable.read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable.write @stored_var, %1 : tensor<f32>
+ return
+}
+
+// -----
+// Ranked i32 tensor variable: same declare/read/add/write sequence as the
+// scalar test.
+// NOTE(review): both tests in this file use only the `= <typed-attr>` form of
+// tosa.variable; the `: <type>` form (no initial value) is not exercised here.
+// CHECK-LABEL: @test_variable_tensor(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+}
More information about the Mlir-commits
mailing list