[Mlir-commits] [mlir] af972f0 - [TOSA] Add StatefulOps to TOSA Dialect (#66843)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 16 16:10:21 PDT 2023
Author: Tai Ly
Date: 2023-10-16T16:10:17-07:00
New Revision: af972f01c01843a9ffe41ff496154267fa387a51
URL: https://github.com/llvm/llvm-project/commit/af972f01c01843a9ffe41ff496154267fa387a51
DIFF: https://github.com/llvm/llvm-project/commit/af972f01c01843a9ffe41ff496154267fa387a51.diff
LOG: [TOSA] Add StatefulOps to TOSA Dialect (#66843)
This patch adds tosa.variable, tosa.variable.read and
tosa.variable.write operators and tests.
Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
---------
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Co-authored-by: Jerry Ge <jerry.ge at arm.com>
Added:
mlir/test/Dialect/Tosa/variables.mlir
Modified:
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index d8d4027500f99c6..c411010603ac61f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
+/// Populates `pm` with the TOSA -> Linalg lowering pipeline; among other
+/// passes it runs tosa-validate configured by `validationOptions`.
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
// Note: Default to 'none' level unless otherwise specified.
- tosa::ValidationOptions const &validationOptions =
- tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
+ // NOTE(review): positional aggregate initialization; presumably the
+ // generated TosaValidationOptions fields are ordered {profile,
+ // StrictOperationSpecAlignment, level} as declared in Passes.td. Confirm
+ // whenever pass options are added or reordered.
+ tosa::TosaValidationOptions const &validationOptions = {
+ tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 555d9bea18ba4dc..a9bc3351f4cff05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,11 @@ class PatternRewriter;
namespace tosa {
+/// Parser for the `custom<TypeOrAttr>` directive used by tosa.variable's
+/// assembly format: fills `typeAttr` with the declared variable type and
+/// `attr` with the optional initial value.
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr);
+/// Printer paired with parseTypeOrAttr. Emits `= <attr>` alone when `attr`
+/// is typed with exactly `type`; otherwise emits `: <type>`, followed by
+/// ` = <attr>` when an initial value is present.
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr);
+
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index d75f5dffa8716c9..f9f25da1b649dea 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}
+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+// Declares a variable identified by the symbol `name`, with a declared `type`
+// and an optional `initial_value`. The type / initial-value tail is handled
+// by tosa::parseTypeOrAttr / tosa::printTypeOrAttr via custom<TypeOrAttr>.
+// No verifier is attached here; duplicate declarations are diagnosed by the
+// tosa-validate pass.
+def Tosa_VariableOp : Tosa_Op<"variable", []> {
+ let summary = "Defines a variable";
+
+ let description = [{
+ Defines a new TOSA variable. This is a mutable value.
+ Modifications are expressed using read/write semantics.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name,
+ TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ let assemblyFormat = [{
+ $name
+ attr-dict
+ custom<TypeOrAttr>($type, $initial_value)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.write
+//===----------------------------------------------------------------------===//
+// Writes `value` to the variable named `name`; the op has no results.
+// tosa-validate checks that `value`'s type equals the declared variable type.
+// NOTE(review): the summary says "write_buffer" while the op is named
+// variable.write. Confirm the intended naming.
+def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
+ let summary = "write_buffer operator";
+
+ let description = [{
+ Assigns a value to pseudo-buffer resource holding a mutable tensor.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name,
+ AnyType:$value
+ );
+
+ let assemblyFormat = [{
+ $name attr-dict `,` $value `:` type($value)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable.read
+//===----------------------------------------------------------------------===//
+// Reads the variable named `name` into the single result `value`.
+// tosa-validate checks that the result type equals the declared variable type.
+// NOTE(review): the summary says "read_buffer" while the op is named
+// variable.read. Confirm the intended naming.
+def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
+ let summary = "read_buffer operator";
+
+ let description = [{
+ Reads the value from a pseudo-buffer resource holding a mutable tensor.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name
+ );
+
+ let results = (outs
+ AnyType:$value
+ );
+
+ let assemblyFormat = [{
+ $name attr-dict `:` type($value)
+ }];
+}
+
#endif // TOSA_UTIL_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 940aed107e2f916..fbfc56dfe2cf4f1 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -68,9 +68,6 @@ struct ValidationOptions {
}
};
-std::unique_ptr<Pass> createTosaValidationPass(
- ValidationOptions const &options = ValidationOptions());
-
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index ac100a6d75c7c08..a0f670de20150fb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,13 +89,12 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}
-def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
criteria, e.g. TOSA profile.
}];
- let constructor = "createTosaValidationPass()";
let options = [
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 718e34ced8d7e70..3c54f85b033b0b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
- tosa::ValidationOptions const &validationOptions) {
+ tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
{options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
- pm.addNestedPass<func::FuncOp>(
- tosa::createTosaValidationPass(validationOptions));
+ pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6db04fe38bcd356..ff34183f9a030a8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -146,6 +146,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Parsers and printers
+//===----------------------------------------------------------------------===//
+
+/// Parses the type / initial-value tail of tosa.variable. Accepted forms:
+///   `= <typed-attr>`      the variable type is taken from the attribute;
+///   `: <type>`            no initial value;
+///   `: <type> = <attr>`   explicit type plus initial value. This is what
+///                         printTypeOrAttr emits when the attribute is
+///                         untyped or its type differs from the declared one.
+ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                                        Attribute &attr) {
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+    // `type` is a required attribute: without an explicit `: <type>` the
+    // initial value must carry the type itself.
+    auto typedAttr = llvm::dyn_cast<TypedAttr>(attr);
+    if (!typedAttr) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected typed attribute";
+    }
+    typeAttr = TypeAttr::get(typedAttr.getType());
+    return success();
+  }
+
+  Type type;
+  if (failed(parser.parseColonType(type))) {
+    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  }
+  typeAttr = TypeAttr::get(type);
+
+  // An explicit type may be followed by an initial value (see the printer).
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+  }
+
+  return success();
+}
+
+/// Prints the type / initial-value tail of tosa.variable in a form accepted by
+/// parseTypeOrAttr: `= <attr>` alone when `attr` is typed with exactly `type`,
+/// otherwise `: <type>` followed by ` = <attr>` if an initial value exists.
+void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                                 Attribute attr) {
+  bool needsSpace = false;
+  auto typedAttr = llvm::dyn_cast_if_present<TypedAttr>(attr);
+  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+    p << ": ";
+    p.printAttribute(type);
+    needsSpace = true; // subsequent attr value needs a space separator
+  }
+  if (attr) {
+    if (needsSpace)
+      p << ' ';
+    p << "= ";
+    p.printAttribute(attr);
+  }
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..d686ce125c13516 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,6 +14,9 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+#include <string>
+#include <unordered_map>
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
@@ -96,12 +99,13 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
explicit TosaValidation() { populateConstantOperandChecks(); }
- explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+ explicit TosaValidation(const TosaValidationOptions &options)
+ : TosaValidation() {
this->profile = options.profile;
- this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+ this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
this->level = options.level;
}
- void runOnOperation() override;
+ void runOnOperation() final;
LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : const_checkers) {
@@ -113,6 +117,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
LogicalResult applyLevelCheck(Operation *op);
+ // check variable read/write data types against variable declarations
+ LogicalResult applyVariableCheck(Operation *op);
+
private:
void populateConstantOperandChecks() {
const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +405,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
+ // Records a tosa.variable declaration; false on duplicate name.
+ bool CheckVariable(Operation *op);
+ // Checks a variable.read/write against its declaration; false on mismatch.
+ bool CheckVariableReadOrWrite(Operation *op);
+
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
+ // Declared type of each tosa.variable, keyed by its (uniqued) symbol name.
+ DenseMap<mlir::StringAttr, mlir::Type> variables_map;
};
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +438,69 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}
+// Returns true if a value of `type` may be stored in / loaded from a variable
+// declared with `declared_type`.
+inline bool CompatibleTypes(const mlir::Type &type,
+                            const mlir::Type &declared_type) {
+  // for now, simply use type equality comparison
+  return type == declared_type;
+}
+
+// Records the declared type of a tosa.variable. Emits an error and returns
+// false if a variable with the same name was already declared in this run.
+bool TosaValidation::CheckVariable(Operation *op) {
+  if (isa<mlir::tosa::VariableOp>(op)) {
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
+
+    // Key on the StringAttr value itself (attributes are uniqued, so equal
+    // names compare equal), never on the address of the local copy.
+    if (variables_map.count(name_attr)) {
+      op->emitOpError() << "name has already been declared";
+      return false;
+    }
+
+    auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
+    variables_map[name_attr] = type_attr.getValue();
+  }
+
+  return true;
+}
+
+// For tosa.variable.read / tosa.variable.write, checks that the named variable
+// was declared earlier in the walk and that every operand and result type
+// matches its declared type. Emits an error and returns false otherwise.
+bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+  if (isa<mlir::tosa::VariableReadOp, mlir::tosa::VariableWriteOp>(op)) {
+    auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
+
+    auto it = variables_map.find(name_attr);
+    if (it == variables_map.end()) {
+      op->emitOpError() << "name has not been declared";
+      return false;
+    }
+
+    mlir::Type var_type = it->second;
+
+    for (auto v : op->getOperands()) {
+      if (!CompatibleTypes(v.getType(), var_type)) {
+        op->emitOpError() << "operand type does not equal variable type";
+        return false;
+      }
+    }
+
+    for (auto v : op->getResults()) {
+      if (!CompatibleTypes(v.getType(), var_type)) {
+        op->emitOpError() << "result type does not equal variable type";
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Runs both variable checks on `op`; each is a no-op for unrelated ops.
+LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
+  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
+    return failure();
+  }
+  return success();
+}
+
void TosaValidation::runOnOperation() {
configLevelAndProfile();
+ // Variable declarations are tracked per run; drop any left over from a
+ // previous run of this pass instance so they are not reported as duplicates.
+ variables_map.clear();
getOperation().walk([&](Operation *op) {
@@ -440,18 +514,18 @@ void TosaValidation::runOnOperation() {
}
}
- // Some uses of TOSA rely on the constant operands of particular operations.
+ // Some uses of TOSA rely on the constant operands of particular
+ // operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();
+
+ // do variable type checks
+ if (failed(applyVariableCheck(op)))
+ signalPassFailure();
});
}
} // namespace
-
-std::unique_ptr<Pass>
-mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
- return std::make_unique<TosaValidation>(options);
-}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c58bb10b9c5ed6..9233662e88db902 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
+
+// -----
+
+// Variable declaration/read/write diagnostics below come from tosa-validate.
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable' op name has already been declared}}
+ tosa.variable @stored_var : tensor<1x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+ %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
+ %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+ tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
+ tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
new file mode 100644
index 000000000000000..9a26aa0bc8bf4d5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// Round-trip tests for tosa.variable / tosa.variable.read /
+// tosa.variable.write: the input is printed (the second RUN line prints the
+// generic form first), re-parsed by a second mlir-opt, and matched.
+// NOTE: neither RUN line passes -split-input-file, so the `// -----`
+// separators below are plain comments and both functions share one module.
+
+
+// -----
+// Rank-0 f32 variable: declare with an initial value, read, add, write back.
+// CHECK-LABEL: @test_variable_scalar(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+ tosa.variable @stored_var = dense<3.14> : tensor<f32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
+ %0 = tosa.variable.read @stored_var : tensor<f32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+ tosa.variable.write @stored_var, %1 : tensor<f32>
+ return
+}
+
+// -----
+// 2x4x8 i32 variable: same declare / read / add / write-back sequence.
+// CHECK-LABEL: @test_variable_tensor(
+// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+ // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
+ // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+ // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+ tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
+ return
+}
More information about the Mlir-commits
mailing list