[Mlir-commits] [mlir] [MLIR][NVVM] Add family-conditional support to NVVMRequiresSM traits (PR #185909)
Srinivasa Ravi
llvmlistbot at llvm.org
Sun Mar 15 23:18:38 PDT 2026
================
@@ -23,55 +23,77 @@ namespace NVVM {
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
- // Set to true if the SM version is accelerated (e.g., sm_90a).
- bool archAccelerated;
+ struct SMVersion {
+ unsigned version;
+ // Set to true if the SM version is accelerated (e.g., sm_90a).
+ bool archAccelerated;
+ // Set to true if the SM version is family-specific (e.g., sm_100f).
+ bool familySpecific;
+
+ unsigned getSmFamilyVersion() const { return version / 10; }
+
+ bool hasFamilySpecificFeatures() const {
+ return familySpecific || archAccelerated;
+ }
+ };
// List of SM versions.
// Typically only has one version except for cases where multiple
// arch-accelerated versions are supported.
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
- llvm::SmallVector<int, 1> smVersionList;
-
- template <typename... Ints>
- NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
- : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
- assert((archAccelerated || smVersionList.size() == 1) &&
- "non arch-accelerated SM version list must be a single version!");
+ llvm::SmallVector<SMVersion, 1> smVersionList;
+
+ template <typename... Versions>
+ NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
+ Versions... smVersions)
+ : smVersionList({SMVersion{static_cast<unsigned>(smVersions),
+ archAccelerated, familySpecific}...}) {
+ assert(
+ !(archAccelerated && familySpecific) &&
+ "archAccelerated and familySpecific cannot be true at the same time!");
}
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
assert(targetSM.smVersionList.size() == 1 &&
"target SM version list must be a single version!");
- if (archAccelerated) {
- if (!targetSM.archAccelerated)
- return false;
-
- for (auto version : smVersionList) {
- if (version == targetSM.smVersionList[0])
- return true;
+ SMVersion targetSMVersion = targetSM.smVersionList[0];
+
+ return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
+ if (RequiredSMVersion.archAccelerated) {
+ return targetSMVersion.archAccelerated &&
+ (RequiredSMVersion.version == targetSMVersion.version);
+ } else if (RequiredSMVersion.familySpecific) {
+ return targetSMVersion.hasFamilySpecificFeatures() &&
+ (RequiredSMVersion.getSmFamilyVersion() ==
+ targetSMVersion.getSmFamilyVersion()) &&
+ (targetSMVersion.version >= RequiredSMVersion.version);
+ } else {
+ return targetSMVersion.version >= RequiredSMVersion.version;
}
- } else {
- return targetSM.smVersionList[0] >= smVersionList[0];
- }
-
- return false;
+ });
}
- bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
+ bool isMinimumSMVersion() const { return smVersionList[0].version >= 20; }
// Parses an SM version string and returns an equivalent NVVMCheckSMVersion
// object.
- static const NVVMCheckSMVersion
+ static NVVMCheckSMVersion
getTargetSMVersionFromStr(StringRef smVersionString) {
bool isAA = smVersionString.back() == 'a';
+ bool isFS = smVersionString.back() == 'f';
int smVersionInt;
smVersionString.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, smVersionInt);
- return NVVMCheckSMVersion(isAA, smVersionInt);
+ return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
+ }
+
+ NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+ smVersionList.append(other.smVersionList);
+ return *this;
----------------
Wolfram70 wrote:
Since we don't really need a `get`, I've changed `append` to return `void`; its return value is never used anyway.
https://github.com/llvm/llvm-project/pull/185909
More information about the Mlir-commits mailing list