[Mlir-commits] [mlir] 32b7c1f - [mlir][TOSA] Set default TOSA validation level to 'None' for TOSA -> linalg

Benjamin Maxwell llvmlistbot at llvm.org
Tue Aug 8 03:19:47 PDT 2023


Author: Benjamin Maxwell
Date: 2023-08-08T10:19:24Z
New Revision: 32b7c1ff3ed890bf3c7e56edbd4892400a915e1f

URL: https://github.com/llvm/llvm-project/commit/32b7c1ff3ed890bf3c7e56edbd4892400a915e1f
DIFF: https://github.com/llvm/llvm-project/commit/32b7c1ff3ed890bf3c7e56edbd4892400a915e1f.diff

LOG: [mlir][TOSA] Set default TOSA validation level to 'None' for TOSA -> linalg

Unless otherwise specified, this pass should not assume a level, as doing so
rejects otherwise valid TOSA. This has caused build failures in IREE.

The level (and other validation options) have now been made configurable.

The pass options have been converted to enums to make them more
type-safe in C++.

Reviewed By: Tai78641

Differential Revision: https://reviews.llvm.org/D157282

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 789f632c7f547c..818d43ffe4e572 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
 #define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
 
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -31,8 +32,11 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// the pass, the function will only contain linalg ops or standard ops if the
 /// pipeline succeeds.  The option to disable decompositions is available for
 /// benchmarking performance improvements from the canonicalizations.
-void addTosaToLinalgPasses(OpPassManager &pm,
-                           bool disableTosaDecompositions = false);
+void addTosaToLinalgPasses(
+    OpPassManager &pm, bool disableTosaDecompositions = false,
+    // Default to 'none', as the 8k level can reject otherwise-valid TOSA.
+    tosa::ValidationOptions const &validationOptions =
+        tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index c81f59b3d5d36a..72846d5dbe4890 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -40,7 +40,31 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
-std::unique_ptr<Pass> createTosaValidationPass();
+
+struct ValidationOptions {
+  /// Validate that operations match the given profile.
+  TosaProfileEnum profile = TosaProfileEnum::Undefined;
+  ValidationOptions &setProfile(TosaProfileEnum profile) {
+    this->profile = profile;
+    return *this;
+  }
+  /// Verify that the properties of certain operations align with the spec.
+  bool strictOperationSpecAlignment = false;
+  ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) {
+    strictOperationSpecAlignment = enable;
+    return *this;
+  }
+  /// Validate that operator parameters are within the specification for the
+  /// given level.
+  TosaLevelEnum level = TosaLevelEnum::EightK;
+  ValidationOptions &setLevel(TosaLevelEnum level) {
+    this->level = level;
+    return *this;
+  }
+};
+
+std::unique_ptr<Pass> createTosaValidationPass(
+    ValidationOptions const &options = ValidationOptions());
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 11f17dc5e66b75..bc30b88ea2af6a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -91,15 +91,32 @@ def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
   let constructor = "createTosaValidationPass()";
 
   let options = [
-      Option<"profileName", "profile", "std::string",
-             /*default=*/"\"undefined\"",
-             "Validate if operations match for the given profile">,
+      Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
+             /*default=*/"mlir::tosa::TosaProfileEnum::Undefined",
+             "Validate if operations match for the given profile",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi",
+                "Use Base Inference profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi",
+                "Use Main Inference profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt",
+                "Use Main Training profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::Undefined, "undefined",
+                "Do not define a profile.")
+              )}]>,
      Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
             /*default=*/"false",
             "Verify if the properties of certain operations align the spec requirement">,
-      Option<"levelName", "level", "std::string",
-             /*default=*/"\"8k\"",
-             "Validate if operator parameters are within specfication for the given level">,
+      Option<"level", "level", "mlir::tosa::TosaLevelEnum",
+             /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
+             "Validate if operator parameters are within specfication for the given level",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
+                "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
+               clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
+                "Allows the full range of arguments specified by the operations according "
+                "to the operation data types.")
+              )}]>
    ];
 }
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c1b6d1c60c73b1..d7e867d9228239 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -74,8 +74,9 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
   return std::make_unique<TosaToLinalg>();
 }
 
-void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
-                                       bool disableTosaDecompositions) {
+void mlir::tosa::addTosaToLinalgPasses(
+    OpPassManager &pm, bool disableTosaDecompositions,
+    tosa::ValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -88,6 +89,7 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
+  pm.addNestedPass<func::FuncOp>(
+      tosa::createTosaValidationPass(validationOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 05da6ee6cad95d..6a5f2bd467dab5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -96,6 +96,11 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
   explicit TosaValidation() { populateConstantOperandChecks(); }
+  explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+    profile = options.profile;
+    StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+    level = options.level;
+  }
   void runOnOperation() override;
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
@@ -387,18 +392,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
-  // configure profile and level values from pass options profileName and
-  // levelName
+  // configure the tosa_level limits from the `level` pass option; the
+  // `profile` option is read in runOnOperation()
   void configLevelAndProfile() {
-    profileType = symbolizeEnum<TosaProfileEnum>(profileName);
-
-    auto levelType = symbolizeEnum<TosaLevelEnum>(levelName);
-
     tosa_level = TOSA_LEVEL_NONE;
-    if (levelType == TosaLevelEnum::EightK) {
+    if (level == TosaLevelEnum::EightK) {
       tosa_level = TOSA_LEVEL_EIGHTK;
     }
   }
 
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
-  std::optional<TosaProfileEnum> profileType;
   tosa_level_t tosa_level;
 };
 
@@ -431,7 +431,7 @@ void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
-      if ((profileType == TosaProfileEnum::BaseInference) &&
+      if ((profile == TosaProfileEnum::BaseInference) &&
           isa<FloatType>(getElementTypeOrSelf(operand))) {
         return signalPassFailure();
       }
@@ -451,6 +451,7 @@ void TosaValidation::runOnOperation() {
 }
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
-  return std::make_unique<TosaValidation>();
+std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass(
+    const ValidationOptions &options) {
+  return std::make_unique<TosaValidation>(options);
 }


        


More information about the Mlir-commits mailing list