[Mlir-commits] [mlir] 940d3e0 - [mlir][tosa] Create a profile validation pass for TOSA dialect
Rob Suderman
llvmlistbot at llvm.org
Mon Nov 14 17:43:37 PST 2022
Author: TatWai Chong
Date: 2022-11-14T17:29:50-08:00
New Revision: 940d3e08cf05bcd2779f6d3879850b2606274e3f
URL: https://github.com/llvm/llvm-project/commit/940d3e08cf05bcd2779f6d3879850b2606274e3f
DIFF: https://github.com/llvm/llvm-project/commit/940d3e08cf05bcd2779f6d3879850b2606274e3f.diff
LOG: [mlir][tosa] Create a profile validation pass for TOSA dialect
Add a separate validation pass that checks whether TOSA operations
conform to the specification under given requirements. Profile type
checking is the initial feature of the pass.
This is an optional pass that can be enabled via command line. e.g.
$mlir-opt --tosa-validate="profile=bi" for validating against the
base inference profile.
Description:
TOSA defines a variety of operator behaviors and requirements in the
specification. It is helpful to have a separate validation pass that
checks TOSA operations against the TOSA specification for given
criteria, and that also reduces the burden of dialect validation during
compilation.
TOSA supports three profiles of which two are for inference purposes.
The main inference profile supports both integer and floating-point
data types, but the base inference profile only supports integers.
In this initial PR, validate the operations against a given TOSA
profile, so that validation fails if a floating-point tensor is
present when the base inference profile is selected. Afterwards, other
checks will be added to the pass as needed, e.g. validation of control
flow operators and custom operators.
The pass is expected to be able to run at any point of the TOSA dialect
conversion/transformation pipeline, and not to depend on a particular
pass having run beforehand. It can therefore be used to validate the
initial TOSA operations just converted from other dialects, the
intermediate form, or the final TOSA operations output.
Change-Id: Ib58349c873c783056e89d2ab3b3312b8d2c61863
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D137279
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Modified:
mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index b1363b5a179df..d4e2661838314 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
+mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTosaPassIncGen)
add_dependencies(mlir-headers MLIRTosaPassIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 9de328a83f87d..d6ae78196f4cb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -37,6 +38,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaValidationPass();
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 46bd7a4780e00..c1334bebe3cc8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
+include "mlir/IR/EnumAttr.td"
include "mlir/Pass/PassBase.td"
def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
@@ -63,4 +64,28 @@ def TosaOptionalDecompositions
let constructor = "tosa::createTosaOptionalDecompositions()";
}
+// Profiles a TOSA program can be validated against. The third argument of
+// each case is the textual spelling used when the enum is symbolized from or
+// stringified to a string.
+def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
+    [
+      I32EnumAttrCase<"BaseInference", 0, "bi">,
+      I32EnumAttrCase<"MainInference", 1, "mi">,
+      I32EnumAttrCase<"MainTraining", 2, "mt">,
+      // Spelled in lower case, like the other cases, so that the "undefined"
+      // default of the tosa-validate `profile` option names this case.
+      I32EnumAttrCase<"Undefined", 3, "undefined">
+    ]>{
+  let cppNamespace = "mlir::tosa";
+}
+
+// Validation pass registered as `tosa-validate` and anchored on func::FuncOp.
+// Its `profile` option is received as a plain string (member `profileName`).
+def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+  let summary = "Validates TOSA dialect";
+  let description = [{
+    This pass validates if input TOSA operations match the specification for given
+    criteria, e.g. TOSA profile.
+  }];
+  // Qualified with `tosa::` to match the constructor of
+  // TosaOptionalDecompositions declared earlier in this file.
+  let constructor = "tosa::createTosaValidationPass()";
+
+  let options = [
+    Option<"profileName", "profile", "std::string",
+           /*default=*/"\"undefined\"",
+           "Validation if ops match for given profile">];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 1cb6e20833150..5290923c25b8a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -84,5 +84,6 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+ pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index ae552693d7310..4f5a54de0c734 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaValidation.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
new file mode 100644
index 0000000000000..36a7a3c9eadb5
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -0,0 +1,68 @@
+//===- TosaValidation.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Validate whether TOSA dialect input matches the specification for given
+// requirements.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAVALIDATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// TOSA Validation Pass.
+//===----------------------------------------------------------------------===//
+
+/// Pass implementation built on the TableGen-generated TosaValidationBase.
+struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
+  explicit TosaValidation() = default;
+
+private:
+  /// Pass entry point; defined out of line below.
+  void runOnOperation() override;
+
+  /// Profile symbolized from the `profileName` option string. It is
+  /// reassigned at the start of every runOnOperation() call.
+  llvm::Optional<TosaProfileEnum> profile_type;
+};
+
+/// Walks every operation nested under the operation this pass runs on and,
+/// when the base inference profile is selected, reports each operation that
+/// has a floating-point operand or result. Each offending operation gets its
+/// own diagnostic and the pass is marked as failed.
+void TosaValidation::runOnOperation() {
+  profile_type = symbolizeEnum<TosaProfileEnum>(profileName);
+
+  // Only the base inference profile carries a check so far; for every other
+  // value (including an option string that names no profile) there is
+  // nothing to validate, so skip the walk entirely.
+  if (!(profile_type == TosaProfileEnum::BaseInference))
+    return;
+
+  auto hasFloatElement = [](Type type) {
+    return getElementTypeOrSelf(type).isa<FloatType>();
+  };
+
+  getOperation().walk([&](Operation *op) {
+    // Results are inspected as well as operands so that an operation whose
+    // floating-point result is never used is still rejected.
+    bool usesFloat = false;
+    for (Type type : op->getOperandTypes())
+      usesFloat |= hasFloatElement(type);
+    for (Type type : op->getResultTypes())
+      usesFloat |= hasFloatElement(type);
+    if (!usesFloat)
+      return;
+
+    // Attach a diagnostic to the offending operation so the failure is not
+    // silent, then keep walking to report the remaining violations too.
+    op->emitOpError("floating-point operand or result is not allowed by the "
+                    "base inference profile");
+    signalPassFailure();
+  });
+}
+} // namespace
+
+/// Factory for the TOSA validation pass; returns a newly allocated
+/// TosaValidation instance owned by the caller.
+std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
+  auto validationPass = std::make_unique<TosaValidation>();
+  return validationPass;
+}
More information about the Mlir-commits
mailing list