[Mlir-commits] [mlir] [mlir][tosa] Check extension cooperative profiles in target environment (PR #185476)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 10:52:21 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
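This commit moves the check that each enabled extension has at least one of its cooperative profiles enabled out of profile compliance validation and into target environment verification. As a result, an invalid extension/profile combination is rejected when the target is created, rather than only when operations are later validated for profile compliance.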
---
Full diff: https://github.com/llvm/llvm-project/pull/185476.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (-26)
- (modified) mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp (+47-1)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp (+4)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (-15)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+7)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+4-11)
- (added) mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir (+9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index d9de79e415292..53372f56f2bf0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -137,32 +137,6 @@ class TosaProfileCompliance {
OpComplianceInfo<T>
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
- SmallVector<Profile> getCooperativeProfiles(Extension ext) {
- switch (ext) {
- case Extension::int16:
- case Extension::int4:
- case Extension::doubleround:
- case Extension::inexactround:
- return {Profile::pro_int};
- case Extension::bf16:
- case Extension::fp8e4m3:
- case Extension::fp8e5m2:
- case Extension::fft:
- case Extension::mxfp:
- case Extension::mxfp_conv:
- return {Profile::pro_fp};
- case Extension::variable:
- case Extension::controlflow:
- case Extension::dynamic:
- case Extension::int64:
- case Extension::shape:
- return {Profile::pro_fp, Profile::pro_int};
- case Extension::none:
- return {};
- };
- llvm_unreachable("bad Extension type");
- }
-
// Debug utilites.
template <typename T>
SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 9f616b223bf92..118d8c6443bee 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -52,6 +52,32 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
llvm_unreachable("Unknown TOSA extension");
}
+SmallVector<Profile, 2> getCooperativeProfiles(Extension ext) {
+ switch (ext) {
+ case Extension::int16:
+ case Extension::int4:
+ case Extension::doubleround:
+ case Extension::inexactround:
+ return {Profile::pro_int};
+ case Extension::bf16:
+ case Extension::fp8e4m3:
+ case Extension::fp8e5m2:
+ case Extension::fft:
+ case Extension::mxfp:
+ case Extension::mxfp_conv:
+ return {Profile::pro_fp};
+ case Extension::variable:
+ case Extension::controlflow:
+ case Extension::dynamic:
+ case Extension::int64:
+ case Extension::shape:
+ return {Profile::pro_fp, Profile::pro_int};
+ case Extension::none:
+ return {};
+ };
+ llvm_unreachable("bad Extension type");
+}
+
TosaSpecificationVersion getMinVersion(const Level &level) {
switch (level) {
case Level::eightK:
@@ -90,14 +116,34 @@ LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
return success();
};
+ const auto isExtensionCooperativeWithProfile =
+ [&](Extension ext) -> LogicalResult {
+ const auto cooperativeProfiles = getCooperativeProfiles(ext);
+
+ const ArrayRef<Profile> targetProfiles = targetAttr.getProfiles();
+ if (!llvm::any_of(cooperativeProfiles,
+ [&targetProfiles](const auto &profile) {
+ return llvm::is_contained(targetProfiles, profile);
+ }))
+ return emitError(targetAttrLoc)
+ << "use of extension '" << stringifyEnum(ext)
+ << "' requires any of profiles: [" << cooperativeProfiles
+ << "] to be enabled in the target";
+
+ return success();
+ };
+
for (const auto &profile : targetAttr.getProfiles())
if (failed(
isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
return failure();
- for (const auto &extension : targetAttr.getExtensions())
+ for (const auto &extension : targetAttr.getExtensions()) {
if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
"extension")))
return failure();
+ if (failed(isExtensionCooperativeWithProfile(extension)))
+ return failure();
+ }
if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
"level")))
return failure();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index a0661e4ee0bd2..410d55d63e5fd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -63,6 +63,10 @@ class TosaAttachTarget
MLIRContext *ctx = &getContext();
const auto targetEnvAttr = TargetEnvAttr::get(
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
+
+ if (failed(TargetEnv::verifyTargetInformation(targetEnvAttr, mod.getLoc())))
+ return signalPassFailure();
+
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 5037d0176d68d..1b824a4d32586 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -435,21 +435,6 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}
- // Each extension can contain a list of profiles that it works with, usually
- // have the same data type.
- if constexpr (std::is_same_v<T, Extension>) {
- for (const auto &mode : opRequiredMode) {
- SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
- if (!targetEnv.allowsAnyOf(coProfs)) {
- op->emitOpError() << "illegal: requires ["
- << llvm::join(stringifyProfile<Profile>(coProfs),
- ", ")
- << "] to work with but not enabled in target\n";
- return failure();
- }
- }
- }
-
// Ensure the profile inference match the profile knowledge of the
// specification.
for (const auto &cands : specRequiredModeSet) {
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 16190bb69411c..f0f001ec8511b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -186,6 +186,13 @@ func.func @test_concat(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x21x3xbf16>)
return %0 : tensor<26x21x3xbf16>
}
+// -----
+func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
+ // expected-error@+1 {{'tosa.concat' op illegal: requires [int16] but not enabled in target}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
+ return %0 : tensor<26x21x3xi16>
+}
+
// -----
func.func @test_pad(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d9de2e0d37c25..d061da14bb109 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index d1e61345ff313..8e56c9c54446b 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -1,8 +1,8 @@
-//--------------------------------------------------------------------------------------------------
-// Enable all supported extensions to focus the verification of expected profile requirement errors.
-//--------------------------------------------------------------------------------------------------
+//-----------------------------------------------------------------------------
+// Check validation of operations when no profiles are specified in the target.
+//-----------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index d7330cb763fc0..fb0ce19dfc5b0 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -1,8 +1,8 @@
-//--------------------------------------------------------------------------------------------------
-// Enable all supported extensions to focus the verification of expected profile requirement errors.
-//--------------------------------------------------------------------------------------------------
+//-----------------------------------------------------------------------------
+// Check operations fail to validate when pro_fp is not provided in the target.
+//-----------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index cb760956fbd68..99b602e48febb 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -1,8 +1,8 @@
-//--------------------------------------------------------------------------------------------------
-// Enable all supported extensions to focus the verification of expected profile requirement errors.
-//--------------------------------------------------------------------------------------------------
+//--------------------------------------------------------------------------------
+// Check operations fail to validate when pro_int is not provided in the target.
+//--------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
@@ -179,13 +179,6 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xi32>) -> tensor<1x21x3xi32> {
return %0 : tensor<1x21x3xi32>
}
-// -----
-func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
- // expected-error@+1 {{'tosa.concat' op illegal: requires [pro_int] to work with but not enabled in target}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
- return %0 : tensor<26x21x3xi16>
-}
-
// -----
func.func @test_pad(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir
new file mode 100644
index 0000000000000..a72228896cd0a
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16"
+
+// expected-error@below {{use of extension 'int16' requires any of profiles: [pro_int] to be enabled in the target}}
+module {
+ func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+ return %1 : tensor<1x1x1x1xf32>
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/185476
More information about the Mlir-commits
mailing list