[Mlir-commits] [mlir] [MLIR][NVVM] Add family-conditional support to NVVMRequiresSM traits (PR #185909)
Srinivasa Ravi
llvmlistbot at llvm.org
Wed Mar 11 09:01:13 PDT 2026
https://github.com/Wolfram70 created https://github.com/llvm/llvm-project/pull/185909
This change adds support for family-conditional SM version requirements to the `NVVMRequiresSM` traits. The following new traits are added:
- `NVVMRequiresSMf` - Op requires an SM version belonging to one of the given families (e.g., `NVVMRequiresSMf<[100]>` is satisfied by `sm_100f`, `sm_100a`, and `sm_103a` targets).
- `NVVMRequiresSMaOrf` - Op requires one of the supported arch-accelerated versions or an SM version belonging to one of the given families.
>From f3c48e3cdebb4530ec3b403de09cc1eaac06f741 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 10 Mar 2026 10:14:39 +0000
Subject: [PATCH] [MLIR][NVVM] Add family-conditional support to NVVMRequiresSM
traits
This change adds support for family-conditional SM version requirements
to the `NVVMRequiresSM` traits. The following new traits are added:
- `NVVMRequiresSMf` - Op requires an SM version belonging to one of the
given families.
- `NVVMRequiresSMaOrf` - Op requires one of the supported arch-accelerated
versions or an SM version belonging to one of the given families.
---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 117 +++++++++++++-----
.../Dialect/LLVMIR/NVVMRequiresSMTraits.td | 20 ++-
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 64 ++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 18 +++
4 files changed, 187 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 36fcaee8ec3a2..87c488eda7b62 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -23,55 +23,79 @@ namespace NVVM {
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
- // Set to true if the SM version is accelerated (e.g., sm_90a).
- bool archAccelerated;
+ // A single SM version together with its suffix kind (none, 'a' or 'f').
+ struct SMVersion {
+ // Numeric SM version (e.g., 100 for sm_100, sm_100a and sm_100f).
+ unsigned version;
+ // Set to true if the SM version is accelerated (e.g., sm_90a).
+ bool archAccelerated;
+ // Set to true if the SM version is family-specific (e.g., sm_100f).
+ bool familySpecific;
+
+ // Family of this SM version: the version without its last digit
+ // (e.g., sm_100 and sm_103 both belong to family 10).
+ unsigned getSmFamilyVersion() const { return version / 10; }
+
+ // Arch-accelerated versions also provide the family-specific features of
+ // their family (e.g., sm_100a satisfies an sm_100f requirement).
+ bool hasFamilySpecificFeatures() const {
+ return familySpecific || archAccelerated;
+ }
+ };
// List of SM versions.
- // Typically only has one version except for cases where multiple
- // arch-accelerated versions are supported.
+ // Typically only has one version except for cases where multiple
+ // arch-accelerated and/or family-specific versions are supported.
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
- llvm::SmallVector<int, 1> smVersionList;
-
- template <typename... Ints>
- NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
- : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
- assert((archAccelerated || smVersionList.size() == 1) &&
- "non arch-accelerated SM version list must be a single version!");
+ // A requirement is satisfied if the target matches any listed version.
+ llvm::SmallVector<SMVersion, 1> smVersionList;
+
+ // Creates a list in which every version shares the same suffix kind.
+ template <typename... Versions>
+ NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
+ Versions... smVersions)
+ : smVersionList({SMVersion{static_cast<unsigned>(smVersions),
+ archAccelerated, familySpecific}...}) {
+ // smVersionList[0] is accessed unconditionally (isMinimumSMVersion), so
+ // an empty list must be rejected.
+ static_assert(sizeof...(Versions) > 0,
+ "at least one SM version must be specified!");
+ assert(
+ !(archAccelerated && familySpecific) &&
+ "archAccelerated and familySpecific cannot be true at the same time!");
}
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
assert(targetSM.smVersionList.size() == 1 &&
"target SM version list must be a single version!");
- if (archAccelerated) {
- if (!targetSM.archAccelerated)
- return false;
-
- for (auto version : smVersionList) {
- if (version == targetSM.smVersionList[0])
- return true;
- }
- } else {
- return targetSM.smVersionList[0] >= smVersionList[0];
- }
-
- return false;
+ const SMVersion &targetSMVersion = targetSM.smVersionList[0];
+
+ return llvm::any_of(smVersionList, [&](const SMVersion &requiredSMVersion) {
+ // sm_XXa: requires an exact arch-accelerated match.
+ if (requiredSMVersion.archAccelerated)
+ return targetSMVersion.archAccelerated &&
+ requiredSMVersion.version == targetSMVersion.version;
+ // sm_XXf: requires a target with family-specific features from the same
+ // family, at or above the required version.
+ if (requiredSMVersion.familySpecific)
+ return targetSMVersion.hasFamilySpecificFeatures() &&
+ requiredSMVersion.getSmFamilyVersion() ==
+ targetSMVersion.getSmFamilyVersion() &&
+ targetSMVersion.version >= requiredSMVersion.version;
+ // sm_XX: any target at or above the required version.
+ return targetSMVersion.version >= requiredSMVersion.version;
+ });
}
- bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
+ bool isMinimumSMVersion() const { return smVersionList[0].version >= 20; }
// Parses an SM version string and returns an equivalent NVVMCheckSMVersion
// object.
- static const NVVMCheckSMVersion
+ static NVVMCheckSMVersion
getTargetSMVersionFromStr(StringRef smVersionString) {
bool isAA = smVersionString.back() == 'a';
+ bool isFS = smVersionString.back() == 'f';
- int smVersionInt;
+ // Zero-initialized: getAsInteger leaves the value untouched on failure
+ // (e.g., no digits after "sm_"), and 0 then fails isMinimumSMVersion().
+ unsigned smVersionInt = 0;
smVersionString.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, smVersionInt);
- return NVVMCheckSMVersion(isAA, smVersionInt);
+ return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
+ }
+
+ // Appends the versions of `other`. Since isCompatibleWith() accepts a
+ // target matching any listed version, the result is satisfied by a target
+ // compatible with either requirement.
+ NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+ smVersionList.append(other.smVersionList);
+ return *this;
}
};
@@ -92,8 +116,8 @@ class NVVMRequiresSM {
: public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
- const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(false, MinVersion);
+ // Plain requirement (neither 'a' nor 'f'): satisfied by any target whose
+ // SM version is at least MinVersion, regardless of the target's suffix.
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVM::NVVMCheckSMVersion(false, false, MinVersion);
}
};
};
@@ -106,12 +130,49 @@ class NVVMRequiresSMa {
NVVMRequiresSMa<SMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
- const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(true, SMVersions...);
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVMRequiresSMa::getRequiredMinSMVersion();
+ }
+ };
+
+ // Static form of the requirement so that composite traits such as
+ // NVVMRequiresSMaOrf can query it without an op instance.
+ static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+ return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
+ }
+};
+
+// Op requires an SM version belonging to one of the given families
+// (e.g., NVVMRequiresSMf<100> is satisfied by sm_100f, sm_100a and sm_103a).
+template <int... SMVersions>
+class NVVMRequiresSMf {
+public:
+ template <typename ConcreteOp>
+ class Impl : public OpTrait::TraitBase<ConcreteOp,
+ NVVMRequiresSMf<SMVersions...>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ public:
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVMRequiresSMf::getRequiredMinSMVersion();
}
};
+
+ // Static form of the requirement so that composite traits such as
+ // NVVMRequiresSMaOrf can query it without an op instance.
+ static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+ return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
+ }
};
+
+// Op is supported if the requirement of either `T` or `U` is met. `T` and `U`
+// must provide a static getRequiredMinSMVersion() (e.g., NVVMRequiresSMa and
+// NVVMRequiresSMf).
+template <typename T, typename U>
+class NVVMRequiresSMaOrf {
+public:
+ template <typename ConcreteOp>
+ class Impl
+ : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSMaOrf<T, U>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ public:
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVMRequiresSMaOrf::getRequiredMinSMVersion();
+ }
+ };
+
+ // Concatenates both version lists; a target compatible with any entry of
+ // the merged list satisfies the combined requirement.
+ static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+ NVVM::NVVMCheckSMVersion result = T::getRequiredMinSMVersion();
+ result.append(U::getRequiredMinSMVersion());
+ return result;
+ }
+};
} // namespace OpTrait
} // namespace mlir
#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 34c0d6b78d5b2..8b35cc48bcfdb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -22,12 +22,12 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
let methods = [
InterfaceMethod<
"Get the SM version required by the op from the trait",
- "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
+ "mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
>
];
}
-// OP requires a specified minimum SM value or higher;
+// Op requires a specified minimum SM value or higher;
// it is not architecture-specific.
class NVVMRequiresSM<int minVersion> :
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
@@ -37,11 +37,23 @@ class StrJoin<string sep, list<string> str_list> {
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
}
-// OP requires an exact SM match along with
-// architecture acceleration.
+// Op requires an exact SM match along with architecture acceleration.
class NVVMRequiresSMa<list<int> smVersions> :
ParamNativeOpTrait<"NVVMRequiresSMa",
StrJoin<",", !foreach(vers, smVersions,
!cast<string>(vers))>.ret>;
+// Op requires an SM version belonging to one of the given families
+// (e.g., NVVMRequiresSMf<[100]> is satisfied by sm_100f, sm_100a, sm_103a).
+class NVVMRequiresSMf<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSMf",
+ StrJoin<",", !foreach(vers, smVersions,
+ !cast<string>(vers))>.ret>;
+
+// Op requires either one of the arch-accelerated versions in `smVersionsA`
+// or an SM version belonging to one of the families in `smVersionsF`.
+// Expands to the C++ trait
+// NVVMRequiresSMaOrf<NVVMRequiresSMa<smVersionsA...>,
+// NVVMRequiresSMf<smVersionsF...>>.
+class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
+ ParamNativeOpTrait<"NVVMRequiresSMaOrf",
+ "mlir::OpTrait::NVVMRequiresSMa<" # StrJoin<",",
+ !foreach(vers, smVersionsA, !cast<string>(vers))>.ret # ">" # "," #
+ "mlir::OpTrait::NVVMRequiresSMf<" # StrJoin<",",
+ !foreach(vers, smVersionsF, !cast<string>(vers))>.ret # ">">;
#endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index e469d336dc1ae..7876f57a4d9ce 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -29,6 +29,37 @@ gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
test.nvvm_requires_sm_90a_or_sm_100a
}
+gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_arch_or_family_1 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_2 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_3 [#nvvm.target<chip = "sm_103a">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
test.nvvm_requires_sm_90a
@@ -42,6 +73,18 @@ gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget =
test.nvvm_requires_sm_90a_or_sm_100a
}
+gpu.module @disable_verify_target4 [#nvvm.target<chip = "sm_120f", verifyTarget = false>] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @disable_verify_target5 [#nvvm.target<chip = "sm_100", verifyTarget = false>] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @disable_verify_target6 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
// -----
gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
@@ -83,3 +126,38 @@ gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_90">] {
// expected-error @below {{is not supported on sm_90}}
test.nvvm_requires_sm_90a_or_sm_100a
}
+
+// -----
+
+gpu.module @check_invalid_SM_family [#nvvm.target<chip = "sm_110a">] {
+ // expected-error @below {{is not supported on sm_110a}}
+ test.nvvm_requires_sm_100f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_family_multi [#nvvm.target<chip = "sm_110a">] {
+ // expected-error @below {{is not supported on sm_110a}}
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_or_family [#nvvm.target<chip = "sm_100">] {
+ // expected-error @below {{is not supported on sm_100}}
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_family_non_specific [#nvvm.target<chip = "sm_100">] {
+ // expected-error @below {{is not supported on sm_100}}
+ test.nvvm_requires_sm_100f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_family_other_arch [#nvvm.target<chip = "sm_90a">] {
+ // expected-error @below {{is not supported on sm_90a}}
+ test.nvvm_requires_sm_100f
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 02bac016eeed1..49fedc643c62a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3095,6 +3095,24 @@ def TestNVVMRequirestSMArchCondMultiOp :
let assemblyFormat = "attr-dict";
}
+// Requires a target with family-specific features in the sm_100 family.
+def TestNVVMRequiresSMFamilyCondOp :
+ TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+// Requires a target with family-specific features in the sm_100 or sm_120
+// family.
+def TestNVVMRequiresSMFamilyCondMultiOp :
+ TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+// Requires exactly sm_90a, or a target with family-specific features in the
+// sm_100 family.
+def TestNVVMRequiresSMArchOrFamilyCondOp :
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued String Attributes
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list