[Mlir-commits] [mlir] 940d3e0 - [mlir][tosa] Create a profile validation pass for TOSA dialect

Rob Suderman llvmlistbot at llvm.org
Mon Nov 14 17:43:37 PST 2022


Author: TatWai Chong
Date: 2022-11-14T17:29:50-08:00
New Revision: 940d3e08cf05bcd2779f6d3879850b2606274e3f

URL: https://github.com/llvm/llvm-project/commit/940d3e08cf05bcd2779f6d3879850b2606274e3f
DIFF: https://github.com/llvm/llvm-project/commit/940d3e08cf05bcd2779f6d3879850b2606274e3f.diff

LOG: [mlir][tosa] Create a profile validation pass for TOSA dialect

Add a separate validation pass that checks whether TOSA operations match
the specification for a given requirement. Profile type checking is
implemented as the initial feature of the pass.

This is an optional pass that can be enabled from the command line, e.g.
`mlir-opt --tosa-validate="profile=bi"` validates against the
base inference profile.

Description:
TOSA defines a variety of operator behaviors and requirements in the
specification. It is helpful to have a separate validation pass that
ensures TOSA operation inputs match the TOSA specification for given
criteria, and that also reduces the burden of dialect validation during
compilation.

TOSA supports three profiles of which two are for inference purposes.
The main inference profile supports both integer and floating-point
data types, but the base inference profile only supports integers.
This initial patch validates operations against a given TOSA profile,
so that validation fails if a floating-point tensor is present when
the base inference profile is selected. Other checks, e.g. validation
of control flow operators and custom operators, can be added to the
pass later as needed.

The pass is expected to be able to run at any point in a TOSA dialect
conversion/transformation pipeline, without depending on any particular
pass having run before it, so that it can be used to validate the initial
TOSA operations just converted from other dialects, the intermediate form,
or the final TOSA operation output.

Change-Id: Ib58349c873c783056e89d2ab3b3312b8d2c61863

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D137279

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index b1363b5a179df..d4e2661838314 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
+mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRTosaPassIncGen)
 add_dependencies(mlir-headers MLIRTosaPassIncGen)
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 9de328a83f87d..d6ae78196f4cb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -37,6 +38,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaValidationPass();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 46bd7a4780e00..c1334bebe3cc8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/Pass/PassBase.td"
 
 def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
@@ -63,4 +64,28 @@ def TosaOptionalDecompositions
   let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
+// Profiles a TOSA program can be validated against. Each case's string is the
+// spelling accepted by the `profile` option of the tosa-validate pass.
+def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
+    [
+      I32EnumAttrCase<"BaseInference", 0, "bi">,
+      I32EnumAttrCase<"MainInference", 1, "mi">,
+      I32EnumAttrCase<"MainTraining", 2, "mt">,
+      // Spelled "undefined" so that the `profile` option's default value
+      // ("undefined") symbolizes to this case instead of matching nothing.
+      I32EnumAttrCase<"Undefined", 3, "undefined">
+    ]>{
+  let cppNamespace = "mlir::tosa";
+}
+
+def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+  let summary = "Validates TOSA dialect";
+  let description = [{
+    This pass validates whether input TOSA operations match the specification
+    for the given criteria, e.g. the TOSA profile. Currently, selecting the
+    base inference profile (`bi`) rejects floating-point types.
+  }];
+  // Namespace-qualified for consistency with the other TOSA passes.
+  let constructor = "tosa::createTosaValidationPass()";
+
+  let options = [
+      Option<"profileName", "profile", "std::string",
+      /*default=*/"\"undefined\"",
+      "Validate if operations match for the given profile (bi, mi or mt)">];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 1cb6e20833150..5290923c25b8a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -84,5 +84,6 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index ae552693d7310..4f5a54de0c734 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
+  TosaValidation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
new file mode 100644
index 0000000000000..36a7a3c9eadb5
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -0,0 +1,68 @@
+//===- TosaValidation.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Validate whether TOSA dialect input matches the specification for the given
+// requirements.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tosa {
+// Pull in the tablegen-generated pass base class
+// (tosa::impl::TosaValidationBase); the GEN_PASS_DEF macro must be defined
+// before Passes.h.inc is included for the definition to be emitted.
+#define GEN_PASS_DEF_TOSAVALIDATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// TOSA Validation Pass.
+//===----------------------------------------------------------------------===//
+
+/// Checks that operations conform to the requirements selected through the
+/// pass options. Currently only the `profile` option is honoured: with
+/// `profile=bi` (base inference) any operation with a floating-point operand
+/// or result is reported. Other or unrecognised profiles perform no checks.
+struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
+public:
+  TosaValidation() = default;
+
+private:
+  void runOnOperation() override;
+
+  /// Profile parsed from the `profile` option; empty if the name did not
+  /// match any TosaProfileEnum case.
+  llvm::Optional<TosaProfileEnum> profile_type;
+};
+
+void TosaValidation::runOnOperation() {
+  profile_type = symbolizeEnum<TosaProfileEnum>(profileName);
+
+  // Only the base inference profile currently imposes restrictions; any
+  // other (or unrecognised) profile accepts every operation.
+  if (profile_type != TosaProfileEnum::BaseInference)
+    return;
+
+  auto isFloat = [](Type type) {
+    return getElementTypeOrSelf(type).isa<FloatType>();
+  };
+
+  bool hasError = false;
+  getOperation().walk([&](Operation *op) {
+    // The base inference profile is integer-only, so a floating-point tensor
+    // in either the operands or the results of an operation is a violation.
+    if (llvm::any_of(op->getOperandTypes(), isFloat) ||
+        llvm::any_of(op->getResultTypes(), isFloat)) {
+      op->emitOpError("uses a floating-point type, which is not supported by "
+                      "the base inference (bi) profile");
+      hasError = true;
+    }
+  });
+
+  // Report every violation before failing, so users see all offending ops.
+  if (hasError)
+    signalPassFailure();
+}
+} // namespace
+
+/// Creates an instance of the TOSA validation pass. The profile to validate
+/// against is configured through the pass's `profile` option.
+std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<TosaValidation>();
+  return pass;
+}


        


More information about the Mlir-commits mailing list