[Mlir-commits] [mlir] [tosa] Change the type of profile option to a list (PR #111214)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 4 14:33:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: TatWai Chong (tatwaichong)

<details>
<summary>Changes</summary>

In the TOSA validation pass, change the type of the profile option to ListOption. TOSA profiles have changed from hierarchical to composable: each profile is an independent set, i.e. a target can implement multiple profiles.

Set the option to `none` by default, and limit validation to the specified profiles if requested.

Change-Id: I1fb8d0c1b27eccd768349b6eb4234093313efb57

---
Full diff: https://github.com/llvm/llvm-project/pull/111214.diff


6 Files Affected:

- (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+2-2) 
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+3-14) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+15-1) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+7-1) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+5-1) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index c84e4f17c38d88..f5ae71cad15243 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,8 @@ void addTosaToLinalgPasses(
         TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
-        tosa::TosaValidationOptions{tosa::TosaProfileEnum::Undefined, false,
-                                    tosa::TosaLevelEnum::None});
+        tosa::TosaValidationOptions{
+            {"none"}, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index c0352fa88fe08d..dac67633769c76 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -76,7 +76,7 @@ def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
       I32EnumAttrCase<"BaseInference", 0, "bi">,
       I32EnumAttrCase<"MainInference", 1, "mi">,
       I32EnumAttrCase<"MainTraining", 2, "mt">,
-      I32EnumAttrCase<"Undefined", 3>
+      I32EnumAttrCase<"Undefined", 3, "none">
     ]>{
   let cppNamespace = "mlir::tosa";
 }
@@ -97,19 +97,8 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
   }];
 
   let options = [
-      Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
-             /*default=*/"mlir::tosa::TosaProfileEnum::Undefined",
-             "Validate if operations match for the given profile",
-             [{::llvm::cl::values(
-               clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi",
-                "Use Base Inference profile."),
-               clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi",
-                "Use Main Inference profile."),
-               clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt",
-                "Use Main Training profile."),
-               clEnumValN(mlir::tosa::TosaProfileEnum::Undefined, "undefined",
-                "Do not define a profile.")
-              )}]>,
+      ListOption<"profile", "profile", "std::string",
+             "Validate if operations match for the given profile set">,
       Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 44036d7c31a912..06a7262c467421 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,7 +115,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
         TosaToLinalgOptions tosaToLinalgOptions;
         TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         TosaValidationOptions validationOptions;
-        validationOptions.profile = tosa::TosaProfileEnum::BaseInference;
+        validationOptions.profile = {"none"};
         validationOptions.StrictOperationSpecAlignment = true;
         validationOptions.level = tosa::TosaLevelEnum::EightK;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..e390a613b58077 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -405,14 +405,28 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     if (level == TosaLevelEnum::EightK) {
       tosaLevel = TOSA_LEVEL_EIGHTK;
     }
+
+    if (!profile.empty()) {
+      for (std::string &prof : profile) {
+        auto profSymbol = symbolizeTosaProfileEnum(prof);
+        if (profSymbol) {
+          enabled_profiles.push_back(profSymbol.value());
+        }
+      }
+    }
   }
 
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
 
   bool isValidElementType(Type type);
+  bool isEnabledProfile(TosaProfileEnum prof) {
+    return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
+           std::end(enabled_profiles);
+  }
 
   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
+  SmallVector<TosaProfileEnum, 3> enabled_profiles;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
 };
@@ -507,7 +521,7 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 
 bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
-    if (profile == TosaProfileEnum::BaseInference)
+    if (!isEnabledProfile(TosaProfileEnum::MainInference))
       return false;
     return type.isF32() || type.isF16() || type.isBF16();
   }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e5c5b9b3663903..b9298b66643538 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1,4 +1,10 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
+//--------------------------------------------------------------------------------------------------
+// Test expected errors in terms of the shape and type of tensor, and the argument type of
+// operation. Excludes the profile compilance checking since it is performed earlier in the
+// validation flow.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt strict-op-spec-alignment"
 
 
 func.func @test_const() -> tensor<1xf32> {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 9b652f2d0bd142..e851019362958f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
+//--------------------------------------------------------------------------------------------------
+// Enable all supported profiles to focus the verification of expected level errors.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt"
 
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/111214


More information about the Mlir-commits mailing list