[Mlir-commits] [mlir] [MLIR][NVVM] Add family-specific support to NVVMRequiresSM traits (PR #185909)
Srinivasa Ravi
llvmlistbot at llvm.org
Wed Apr 1 07:56:21 PDT 2026
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/185909
>From f3c48e3cdebb4530ec3b403de09cc1eaac06f741 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 10 Mar 2026 10:14:39 +0000
Subject: [PATCH 01/10] [MLIR][NVVM] Add family-conditional support to
NVVMRequiresSM traits
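This change adds support for family-conditional SM version requirements
to the `NVVMRequiresSM` traits. The following new traits are added:
- `NVVMRequiresSMf` - Op requires a target SM version that belongs to one
of the given families and is at or above the listed version. For example,
a requirement of sm_100f is satisfied by sm_100f, sm_100a, and sm_103a
targets, but not by sm_100 (no family-specific features) or sm_110a
(a different family).
- `NVVMRequiresSMaOrf` - Op requires either one of the listed
arch-accelerated versions or an SM version belonging to one of the given
families.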
---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 117 +++++++++++++-----
.../Dialect/LLVMIR/NVVMRequiresSMTraits.td | 20 ++-
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 64 ++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 18 +++
4 files changed, 187 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 36fcaee8ec3a2..87c488eda7b62 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -23,55 +23,79 @@ namespace NVVM {
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
- // Set to true if the SM version is accelerated (e.g., sm_90a).
- bool archAccelerated;
+ struct SMVersion {
+ unsigned version;
+ // Set to true if the SM version is accelerated (e.g., sm_90a).
+ bool archAccelerated;
+ // Set to true if the SM version is family-specific (e.g., sm_100f).
+ bool familySpecific;
+
+ unsigned getSmFamilyVersion() const {
+ return version / 10;
+ }
+
+ bool hasFamilySpecificFeatures() const {
+ return familySpecific || archAccelerated;
+ }
+ };
// List of SM versions.
// Typically only has one version except for cases where multiple
// arch-accelerated versions are supported.
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
- llvm::SmallVector<int, 1> smVersionList;
-
- template <typename... Ints>
- NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
- : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
- assert((archAccelerated || smVersionList.size() == 1) &&
- "non arch-accelerated SM version list must be a single version!");
+ llvm::SmallVector<SMVersion, 1> smVersionList;
+
+ template <typename... Versions>
+ NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
+ Versions... smVersions)
+ : smVersionList({SMVersion{static_cast<unsigned>(smVersions),
+ archAccelerated, familySpecific}...}) {
+ assert(
+ !(archAccelerated && familySpecific) &&
+ "archAccelerated and familySpecific cannot be true at the same time!");
}
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
assert(targetSM.smVersionList.size() == 1 &&
"target SM version list must be a single version!");
- if (archAccelerated) {
- if (!targetSM.archAccelerated)
- return false;
-
- for (auto version : smVersionList) {
- if (version == targetSM.smVersionList[0])
- return true;
+ SMVersion targetSMVersion = targetSM.smVersionList[0];
+
+ return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
+ if (RequiredSMVersion.archAccelerated) {
+ return targetSMVersion.archAccelerated &&
+ (RequiredSMVersion.version == targetSMVersion.version);
+ } else if (RequiredSMVersion.familySpecific) {
+ return targetSMVersion.hasFamilySpecificFeatures() &&
+ (RequiredSMVersion.getSmFamilyVersion() ==
+ targetSMVersion.getSmFamilyVersion()) &&
+ (targetSMVersion.version >= RequiredSMVersion.version);
+ } else {
+ return targetSMVersion.version >= RequiredSMVersion.version;
}
- } else {
- return targetSM.smVersionList[0] >= smVersionList[0];
- }
-
- return false;
+ });
}
- bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
+ bool isMinimumSMVersion() const { return smVersionList[0].version >= 20; }
// Parses an SM version string and returns an equivalent NVVMCheckSMVersion
// object.
- static const NVVMCheckSMVersion
+ static NVVMCheckSMVersion
getTargetSMVersionFromStr(StringRef smVersionString) {
bool isAA = smVersionString.back() == 'a';
+ bool isFS = smVersionString.back() == 'f';
int smVersionInt;
smVersionString.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, smVersionInt);
- return NVVMCheckSMVersion(isAA, smVersionInt);
+ return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
+ }
+
+ NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+ smVersionList.append(other.smVersionList);
+ return *this;
}
};
@@ -92,8 +116,8 @@ class NVVMRequiresSM {
: public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
- const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(false, MinVersion);
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVM::NVVMCheckSMVersion(false, false, MinVersion);
}
};
};
@@ -106,12 +130,49 @@ class NVVMRequiresSMa {
NVVMRequiresSMa<SMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
- const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(true, SMVersions...);
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
+ }
+ };
+
+ static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+ return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
+ }
+};
+
+template <int... SMVersions>
+class NVVMRequiresSMf {
+public:
+ template <typename ConcreteOp>
+ class Impl : public OpTrait::TraitBase<ConcreteOp,
+ NVVMRequiresSMf<SMVersions...>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ public:
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
}
};
+
+ static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+ return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
+ }
};
+template <typename T, typename U>
+class NVVMRequiresSMaOrf {
+public:
+ template <typename ConcreteOp>
+ class Impl
+ : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSMaOrf<T, U>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ public:
+ NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ auto result = T::getRequiredMinSMVersion();
+ result.append(U::getRequiredMinSMVersion());
+ return result;
+ }
+ };
+};
} // namespace OpTrait
} // namespace mlir
#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 34c0d6b78d5b2..8b35cc48bcfdb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -22,12 +22,12 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
let methods = [
InterfaceMethod<
"Get the SM version required by the op from the trait",
- "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
+ "mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
>
];
}
-// OP requires a specified minimum SM value or higher;
+// Op requires a specified minimum SM value or higher;
// it is not architecture-specific.
class NVVMRequiresSM<int minVersion> :
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
@@ -37,11 +37,23 @@ class StrJoin<string sep, list<string> str_list> {
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
}
-// OP requires an exact SM match along with
-// architecture acceleration.
+// Op requires an exact SM match along with architecture acceleration.
class NVVMRequiresSMa<list<int> smVersions> :
ParamNativeOpTrait<"NVVMRequiresSMa",
StrJoin<",", !foreach(vers, smVersions,
!cast<string>(vers))>.ret>;
+// Op requires an SM version belonging to the family.
+class NVVMRequiresSMf<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSMf",
+ StrJoin<",", !foreach(vers, smVersions,
+ !cast<string>(vers))>.ret>;
+
+// Op supported on some combination of architecture acceleration and family-specific SM versions.
+class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
+ ParamNativeOpTrait<"NVVMRequiresSMaOrf",
+ "mlir::OpTrait::NVVMRequiresSMa<" # StrJoin<",",
+ !foreach(vers, smVersionsA, !cast<string>(vers))>.ret # ">" # "," #
+ "mlir::OpTrait::NVVMRequiresSMf<" # StrJoin<",",
+ !foreach(vers, smVersionsF, !cast<string>(vers))>.ret # ">">;
#endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index e469d336dc1ae..7876f57a4d9ce 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -29,6 +29,37 @@ gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
test.nvvm_requires_sm_90a_or_sm_100a
}
+gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_arch_or_family_1 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_2 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_3 [#nvvm.target<chip = "sm_103a">] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
test.nvvm_requires_sm_90a
@@ -42,6 +73,18 @@ gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget =
test.nvvm_requires_sm_90a_or_sm_100a
}
+gpu.module @disable_verify_target4 [#nvvm.target<chip = "sm_120f", verifyTarget = false>] {
+ test.nvvm_requires_sm_100f
+}
+
+gpu.module @disable_verify_target5 [#nvvm.target<chip = "sm_100", verifyTarget = false>] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @disable_verify_target6 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
+
// -----
gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
@@ -83,3 +126,24 @@ gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_90">] {
// expected-error @below {{is not supported on sm_90}}
test.nvvm_requires_sm_90a_or_sm_100a
}
+
+// -----
+
+gpu.module @check_invalid_SM_family [#nvvm.target<chip = "sm_110a">] {
+ // expected-error @below {{is not supported on sm_110a}}
+ test.nvvm_requires_sm_100f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_family_multi [#nvvm.target<chip = "sm_110a">] {
+ // expected-error @below {{is not supported on sm_110a}}
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_or_family [#nvvm.target<chip = "sm_100">] {
+ // expected-error @below {{is not supported on sm_100}}
+ test.nvvm_requires_sm_90a_or_sm_100f
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 02bac016eeed1..49fedc643c62a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3095,6 +3095,24 @@ def TestNVVMRequirestSMArchCondMultiOp :
let assemblyFormat = "attr-dict";
}
+def TestNVVMRequiresSMFamilyCondOp :
+ TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMFamilyCondMultiOp :
+ TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchOrFamilyCondOp :
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued String Attributes
//===----------------------------------------------------------------------===//
>From 3cf3e1513808fe6439c0241abb84d3d67c7554cc Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 12 Mar 2026 08:03:24 +0000
Subject: [PATCH 02/10] fix formatting
---
mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 87c488eda7b62..dbe4dfd17e84d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -30,10 +30,8 @@ struct NVVMCheckSMVersion {
// Set to true if the SM version is family-specific (e.g., sm_100f).
bool familySpecific;
- unsigned getSmFamilyVersion() const {
- return version / 10;
- }
-
+ unsigned getSmFamilyVersion() const { return version / 10; }
+
bool hasFamilySpecificFeatures() const {
return familySpecific || archAccelerated;
}
@@ -92,7 +90,7 @@ struct NVVMCheckSMVersion {
return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
}
-
+
NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
smVersionList.append(other.smVersionList);
return *this;
>From b0dc8e4a9b003560409395e7bf64a9818dc2225c Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 12 Mar 2026 12:04:58 +0000
Subject: [PATCH 03/10] replace StrJoin with interleave
---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.td | 19 ++++---------------
1 file changed, 4 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 8b35cc48bcfdb..8c3d34fe033e9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -32,28 +32,17 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
class NVVMRequiresSM<int minVersion> :
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
-class StrJoin<string sep, list<string> str_list> {
- string ret = !foldl("", str_list, a, b,
- !if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
-}
-
// Op requires an exact SM match along with architecture acceleration.
class NVVMRequiresSMa<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMa",
- StrJoin<",", !foreach(vers, smVersions,
- !cast<string>(vers))>.ret>;
+ ParamNativeOpTrait<"NVVMRequiresSMa", !interleave(smVersions, ",")>;
// Op requires an SM version belonging to the family.
class NVVMRequiresSMf<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMf",
- StrJoin<",", !foreach(vers, smVersions,
- !cast<string>(vers))>.ret>;
+ ParamNativeOpTrait<"NVVMRequiresSMf", !interleave(smVersions, ",")>;
// Op supported on some combination of architecture acceleration and family-specific SM versions.
class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
ParamNativeOpTrait<"NVVMRequiresSMaOrf",
- "mlir::OpTrait::NVVMRequiresSMa<" # StrJoin<",",
- !foreach(vers, smVersionsA, !cast<string>(vers))>.ret # ">" # "," #
- "mlir::OpTrait::NVVMRequiresSMf<" # StrJoin<",",
- !foreach(vers, smVersionsF, !cast<string>(vers))>.ret # ">">;
+ "mlir::OpTrait::NVVMRequiresSMa<" # !interleave(smVersionsA, ",") # ">" # "," #
+ "mlir::OpTrait::NVVMRequiresSMf<" # !interleave(smVersionsF, ",") # ">">;
#endif //NVVM_REQUIRES_SM_TRAITS
>From 3020c5af597710d06a2da3d9f0b475d098deb6bb Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 16 Mar 2026 06:14:17 +0000
Subject: [PATCH 04/10] address comments
---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 31 ++++++++++---------
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index dbe4dfd17e84d..f6718ad0e29f6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -24,8 +24,9 @@ namespace NVVM {
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
struct SMVersion {
+ // Base SM version (e.g., sm_100)
unsigned version;
- // Set to true if the SM version is accelerated (e.g., sm_90a).
+ // Set to true if the SM version is accelerated (e.g., sm_100a).
bool archAccelerated;
// Set to true if the SM version is family-specific (e.g., sm_100f).
bool familySpecific;
@@ -39,9 +40,9 @@ struct NVVMCheckSMVersion {
// List of SM versions.
// Typically only has one version except for cases where multiple
- // arch-accelerated versions are supported.
+ // arch-accelerated or family-conditional versions are supported.
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
- llvm::SmallVector<SMVersion, 1> smVersionList;
+ llvm::SmallVector<SMVersion> smVersionList;
template <typename... Versions>
NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
@@ -55,22 +56,22 @@ struct NVVMCheckSMVersion {
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
assert(targetSM.smVersionList.size() == 1 &&
- "target SM version list must be a single version!");
+ "target SM version list must have a single version!");
SMVersion targetSMVersion = targetSM.smVersionList[0];
return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
- if (RequiredSMVersion.archAccelerated) {
+ if (RequiredSMVersion.archAccelerated)
return targetSMVersion.archAccelerated &&
(RequiredSMVersion.version == targetSMVersion.version);
- } else if (RequiredSMVersion.familySpecific) {
+
+ if (RequiredSMVersion.familySpecific)
return targetSMVersion.hasFamilySpecificFeatures() &&
(RequiredSMVersion.getSmFamilyVersion() ==
targetSMVersion.getSmFamilyVersion()) &&
(targetSMVersion.version >= RequiredSMVersion.version);
- } else {
- return targetSMVersion.version >= RequiredSMVersion.version;
- }
+
+ return targetSMVersion.version >= RequiredSMVersion.version;
});
}
@@ -91,9 +92,8 @@ struct NVVMCheckSMVersion {
return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
}
- NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+ void append(const NVVMCheckSMVersion &other) {
smVersionList.append(other.smVersionList);
- return *this;
}
};
@@ -156,17 +156,18 @@ class NVVMRequiresSMf {
}
};
-template <typename T, typename U>
+template <typename SMVersionsA, typename SMVersionsF>
class NVVMRequiresSMaOrf {
public:
template <typename ConcreteOp>
class Impl
- : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSMaOrf<T, U>::Impl>,
+ : public OpTrait::TraitBase<
+ ConcreteOp, NVVMRequiresSMaOrf<SMVersionsA, SMVersionsF>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- auto result = T::getRequiredMinSMVersion();
- result.append(U::getRequiredMinSMVersion());
+ auto result = SMVersionsA::getRequiredMinSMVersion();
+ result.append(SMVersionsF::getRequiredMinSMVersion());
return result;
}
};
>From 6535e13f5ce6e3aa2029a2934fae5157de163508 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 18 Mar 2026 08:06:30 +0000
Subject: [PATCH 05/10] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 42 +++++++++----------
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 24 +++++++----
.../Dialect/LLVMIR/NVVMRequiresSMTraits.td | 16 +++----
mlir/test/lib/Dialect/Test/TestOps.td | 10 ++---
4 files changed, 49 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e1ab38e80d4..33442189ff76f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2247,7 +2247,7 @@ def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "conve
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
// These operations always use RS (stochastic rounding) mode with SATFINITE saturation.
class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type dstType> :
- NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMAA<[100, 103]>]>,
Results<(outs dstType:$dst)>,
Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
DefaultValuedAttr<BoolAttr, "false">:$relu,
@@ -4514,7 +4514,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMAA<[90]>]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -4528,7 +4528,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -4540,7 +4540,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [N
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
@@ -4902,7 +4902,7 @@ def Tcgen05WaitKindAttr :
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 alloc operation";
let description = [{
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -4932,7 +4932,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]
}];
}
-def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 dealloc operation";
let description = [{
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -4960,7 +4960,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10
}];
}
-def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 Op to relinquish the right to allocate";
let description = [{
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -4983,7 +4983,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];
}
-def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 fence operations";
let description = [{
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -5005,7 +5005,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]
}];
}
-def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 wait operations";
let description = [{
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -5027,7 +5027,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]>
}];
}
-def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 commit operations";
let description = [{
The `tcgen05.commit` makes the *mbarrier object*, specified by
@@ -5065,7 +5065,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]
}];
}
-def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMAA<[100, 101, 103]>]> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -5131,7 +5131,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -5267,7 +5267,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
@@ -5360,7 +5360,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
//===----------------------------------------------------------------------===//
def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
- [NVVMRequiresSMa<[101, 110]>]> {
+ [NVVMRequiresSMAA<[101, 110]>]> {
let summary = "Tcgen05 tensor memory load and reduce instructions";
let arguments = (ins
Tcgen05LdStShapeAttr:$shape,
@@ -5449,7 +5449,7 @@ def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMAA<[100, 101]>]> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
@@ -5852,7 +5852,7 @@ defvar Tcgen05MMABlockScaleKindAttr =
def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
[AttrSizedOperandSegments,
- NVVMRequiresSMa<[100, 110]>]> {
+ NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs MMA operation on 5th-gen tensor cores";
let description = [{
@@ -5936,7 +5936,7 @@ def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
[AttrSizedOperandSegments,
- NVVMRequiresSMa<[100, 110]>]> {
+ NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs MMA operation with sparse A matrix on 5th-gen tensor cores";
let description = [{
@@ -6017,7 +6017,7 @@ def Tcgen05MMABlockScaleAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScale,
}
def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
- [NVVMRequiresSMa<[100, 110]>]> {
+ [NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs block scaled MMA operation on 5th-gen tensor cores";
let description = [{
@@ -6090,7 +6090,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
}
def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
- [NVVMRequiresSMa<[100, 110]>]> {
+ [NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs block scaled MMA operation with sparse A matrix on 5th-gen tensor cores";
let description = [{
@@ -6172,7 +6172,7 @@ def Tcgen05MMACollectorBBufferAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorB
}
def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
- [NVVMRequiresSMa<[100, 110]>]> {
+ [NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs weight stationary convolution MMA operation on 5th-gen tensor cores";
let description = [{
@@ -6242,7 +6242,7 @@ def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
}
def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp",
- [NVVMRequiresSMa<[100, 110]>]> {
+ [NVVMRequiresSMAA<[100, 110]>]> {
let summary = "Performs weight stationary convolution MMA with sparse A matrix on 5th-gen tensor cores";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index f6718ad0e29f6..e0bb3c5be8ab7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -120,12 +120,14 @@ class NVVMRequiresSM {
};
};
+// SMVersions is a template parameter pack of the supported
+// architecture-accelerated SM versions.
template <int... SMVersions>
-class NVVMRequiresSMa {
+class NVVMRequiresSMAA {
public:
template <typename ConcreteOp>
class Impl : public OpTrait::TraitBase<ConcreteOp,
- NVVMRequiresSMa<SMVersions...>::Impl>,
+ NVVMRequiresSMAA<SMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
@@ -138,12 +140,14 @@ class NVVMRequiresSMa {
}
};
+// SMVersions is a template parameter pack of the supported family-specific SM
+// versions.
template <int... SMVersions>
-class NVVMRequiresSMf {
+class NVVMRequiresSMFS {
public:
template <typename ConcreteOp>
class Impl : public OpTrait::TraitBase<ConcreteOp,
- NVVMRequiresSMf<SMVersions...>::Impl>,
+ NVVMRequiresSMFS<SMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
@@ -156,18 +160,20 @@ class NVVMRequiresSMf {
}
};
-template <typename SMVersionsA, typename SMVersionsF>
-class NVVMRequiresSMaOrf {
+// SMVersionsAA (SMVersionsFS) is a template parameter pack of the supported
+// architecture-accelerated (family-specific) SM versions.
+template <typename SMVersionsAA, typename SMVersionsFS>
+class NVVMRequiresSMAAOrFS {
public:
template <typename ConcreteOp>
class Impl
: public OpTrait::TraitBase<
- ConcreteOp, NVVMRequiresSMaOrf<SMVersionsA, SMVersionsF>::Impl>,
+ ConcreteOp, NVVMRequiresSMAAOrFS<SMVersionsAA, SMVersionsFS>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- auto result = SMVersionsA::getRequiredMinSMVersion();
- result.append(SMVersionsF::getRequiredMinSMVersion());
+ auto result = SMVersionsAA::getRequiredMinSMVersion();
+ result.append(SMVersionsFS::getRequiredMinSMVersion());
return result;
}
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 8c3d34fe033e9..0b009819fc2c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -33,16 +33,16 @@ class NVVMRequiresSM<int minVersion> :
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
// Op requires an exact SM match along with architecture acceleration.
-class NVVMRequiresSMa<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMa", !interleave(smVersions, ",")>;
+class NVVMRequiresSMAA<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSMAA", !interleave(smVersions, ",")>;
// Op requires an SM version belonging to the family.
-class NVVMRequiresSMf<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMf", !interleave(smVersions, ",")>;
+class NVVMRequiresSMFS<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSMFS", !interleave(smVersions, ",")>;
// Op supported on some combination of architecture acceleration and family-specific SM versions.
-class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
- ParamNativeOpTrait<"NVVMRequiresSMaOrf",
- "mlir::OpTrait::NVVMRequiresSMa<" # !interleave(smVersionsA, ",") # ">" # "," #
- "mlir::OpTrait::NVVMRequiresSMf<" # !interleave(smVersionsF, ",") # ">">;
+class NVVMRequiresSMAAOrFS<list<int> smVersionsAA, list<int> smVersionsFS> :
+ ParamNativeOpTrait<"NVVMRequiresSMAAOrFS",
+ "mlir::OpTrait::NVVMRequiresSMAA<" # !interleave(smVersionsAA, ",") # ">" # "," #
+ "mlir::OpTrait::NVVMRequiresSMFS<" # !interleave(smVersionsFS, ",") # ">">;
#endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 49fedc643c62a..38df7d5b13e50 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3084,31 +3084,31 @@ def TestNVVMRequiresSMOp :
}
def TestNVVMRequiresSMArchCondOp :
- TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> {
+ TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMAA<[90]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequirestSMArchCondMultiOp :
- TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMAA<[90, 100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMFamilyCondOp :
- TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
+ TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMFS<[100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMFamilyCondMultiOp :
- TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
+ TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMFS<[100, 120]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMArchOrFamilyCondOp :
- TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMAAOrFS<[90], [100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
>From 6185c9b7edab8940ea4e2332a862da869398eaba Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 25 Mar 2026 10:34:59 +0000
Subject: [PATCH 06/10] address comments and switch to use FullSMVersion
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 42 ++---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 164 +++++-------------
.../Dialect/LLVMIR/NVVMRequiresSMTraits.td | 26 +--
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 +-
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 72 ++++----
mlir/test/lib/Dialect/Test/TestOps.td | 10 +-
6 files changed, 127 insertions(+), 195 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 33442189ff76f..f8e1ab38e80d4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2247,7 +2247,7 @@ def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "conve
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
// These operations always use RS (stochastic rounding) mode with SATFINITE saturation.
class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type dstType> :
- NVVM_Op<mnemonic, [Pure, NVVMRequiresSMAA<[100, 103]>]>,
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
Results<(outs dstType:$dst)>,
Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
DefaultValuedAttr<BoolAttr, "false">:$relu,
@@ -4514,7 +4514,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMAA<[90]>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -4528,7 +4528,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMAA<
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -4540,7 +4540,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [N
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
@@ -4902,7 +4902,7 @@ def Tcgen05WaitKindAttr :
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 alloc operation";
let description = [{
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -4932,7 +4932,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMAA<[100, 101]>
}];
}
-def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 dealloc operation";
let description = [{
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -4960,7 +4960,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMAA<[100, 1
}];
}
-def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 Op to relinquish the right to allocate";
let description = [{
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -4983,7 +4983,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];
}
-def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 fence operations";
let description = [{
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -5005,7 +5005,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMAA<[100, 101]>
}];
}
-def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 wait operations";
let description = [{
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -5027,7 +5027,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMAA<[100, 101]>]>
}];
}
-def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 commit operations";
let description = [{
The `tcgen05.commit` makes the *mbarrier object*, specified by
@@ -5065,7 +5065,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMAA<[100, 101
}];
}
-def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMAA<[100, 101, 103]>]> {
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -5131,7 +5131,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -5267,7 +5267,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
@@ -5360,7 +5360,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMAA<[100, 101]>]> {
//===----------------------------------------------------------------------===//
def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
- [NVVMRequiresSMAA<[101, 110]>]> {
+ [NVVMRequiresSMa<[101, 110]>]> {
let summary = "Tcgen05 tensor memory load and reduce instructions";
let arguments = (ins
Tcgen05LdStShapeAttr:$shape,
@@ -5449,7 +5449,7 @@ def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMAA<[100, 101]>]> {
+def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
@@ -5852,7 +5852,7 @@ defvar Tcgen05MMABlockScaleKindAttr =
def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
[AttrSizedOperandSegments,
- NVVMRequiresSMAA<[100, 110]>]> {
+ NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs MMA operation on 5th-gen tensor cores";
let description = [{
@@ -5936,7 +5936,7 @@ def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
[AttrSizedOperandSegments,
- NVVMRequiresSMAA<[100, 110]>]> {
+ NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs MMA operation with sparse A matrix on 5th-gen tensor cores";
let description = [{
@@ -6017,7 +6017,7 @@ def Tcgen05MMABlockScaleAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScale,
}
def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
- [NVVMRequiresSMAA<[100, 110]>]> {
+ [NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs block scaled MMA operation on 5th-gen tensor cores";
let description = [{
@@ -6090,7 +6090,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
}
def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
- [NVVMRequiresSMAA<[100, 110]>]> {
+ [NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs block scaled MMA operation with sparse A matrix on 5th-gen tensor cores";
let description = [{
@@ -6172,7 +6172,7 @@ def Tcgen05MMACollectorBBufferAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorB
}
def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
- [NVVMRequiresSMAA<[100, 110]>]> {
+ [NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs weight stationary convolution MMA operation on 5th-gen tensor cores";
let description = [{
@@ -6242,7 +6242,7 @@ def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
}
def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp",
- [NVVMRequiresSMAA<[100, 110]>]> {
+ [NVVMRequiresSMa<[100, 110]>]> {
let summary = "Performs weight stationary convolution MMA with sparse A matrix on 5th-gen tensor cores";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index e0bb3c5be8ab7..82f41d88c1037 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -23,77 +23,61 @@ namespace NVVM {
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
- struct SMVersion {
- // Base SM version (e.g., sm_100)
- unsigned version;
- // Set to true if the SM version is accelerated (e.g., sm_100a).
- bool archAccelerated;
- // Set to true if the SM version is family-specific (e.g., sm_100f).
- bool familySpecific;
-
- unsigned getSmFamilyVersion() const { return version / 10; }
-
- bool hasFamilySpecificFeatures() const {
- return familySpecific || archAccelerated;
- }
- };
-
- // List of SM versions.
- // Typically only has one version except for cases where multiple
- // arch-accelerated or family-conditional versions are supported.
- // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
- llvm::SmallVector<SMVersion> smVersionList;
+ // List of supported full SM versions.
+ // This is used to check compatibility with a target SM version.
+ // The full SM version is encoded as SM * 10 + ArchSuffixOffset where:
+ // - SM is the SM version (e.g., 100)
+ // - ArchSuffixOffset is 0 for base, 2 for family-specific, and 3 for
+ // architecture-accelerated
+ //
+ // For example, sm_100 is encoded as 1000 (100 * 10 + 0), sm_100f is encoded
+ // as 1002 (100 * 10 + 2) and sm_100a is encoded as 1003 (100 * 10 + 3).
+ llvm::SmallVector<unsigned> fullSmVersionList;
template <typename... Versions>
- NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
- Versions... smVersions)
- : smVersionList({SMVersion{static_cast<unsigned>(smVersions),
- archAccelerated, familySpecific}...}) {
- assert(
- !(archAccelerated && familySpecific) &&
- "archAccelerated and familySpecific cannot be true at the same time!");
+ NVVMCheckSMVersion(Versions... fullSmVersions)
+ : fullSmVersionList({fullSmVersions...}) {}
+
+ bool isCompatibleWith(const unsigned &targetFullSmVersion) const {
+ return llvm::any_of(
+ fullSmVersionList, [&](const unsigned &requiredFullSmVersion) {
+ if (hasArchAcceleratedFeatures(requiredFullSmVersion))
+ return targetFullSmVersion == requiredFullSmVersion;
+
+ if (hasFamilySpecificFeatures(requiredFullSmVersion))
+ return hasFamilySpecificFeatures(targetFullSmVersion) &&
+ ((targetFullSmVersion / 100) ==
+ (requiredFullSmVersion / 100));
+
+ return targetFullSmVersion >= requiredFullSmVersion;
+ });
}
- bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
- assert(targetSM.smVersionList.size() == 1 &&
- "target SM version list must have a single version!");
-
- SMVersion targetSMVersion = targetSM.smVersionList[0];
-
- return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
- if (RequiredSMVersion.archAccelerated)
- return targetSMVersion.archAccelerated &&
- (RequiredSMVersion.version == targetSMVersion.version);
-
- if (RequiredSMVersion.familySpecific)
- return targetSMVersion.hasFamilySpecificFeatures() &&
- (RequiredSMVersion.getSmFamilyVersion() ==
- targetSMVersion.getSmFamilyVersion()) &&
- (targetSMVersion.version >= RequiredSMVersion.version);
-
- return targetSMVersion.version >= RequiredSMVersion.version;
- });
+ static bool isMinimumSMVersion(unsigned targetFullSmVersion) {
+ return targetFullSmVersion >= 200;
}
- bool isMinimumSMVersion() const { return smVersionList[0].version >= 20; }
-
- // Parses an SM version string and returns an equivalent NVVMCheckSMVersion
- // object.
- static NVVMCheckSMVersion
- getTargetSMVersionFromStr(StringRef smVersionString) {
+ // Parses an SM version string and returns an equivalent full SM version
+ // integer.
+  static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString) {
bool isAA = smVersionString.back() == 'a';
bool isFS = smVersionString.back() == 'f';
-  int smVersionInt;
+  unsigned smVersion = 0; // Stays 0 if getAsInteger fails to parse.
smVersionString.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
-      .getAsInteger(10, smVersionInt);
+      .getAsInteger(10, smVersion);
+
+    return smVersion * 10 + (isAA ? 3 : 0) + (isFS ? 2 : 0);
+  }
- return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
+private:
+ bool hasFamilySpecificFeatures(unsigned fullSmVersion) const {
+ return (fullSmVersion % 10) != 0;
}
- void append(const NVVMCheckSMVersion &other) {
- smVersionList.append(other.smVersionList);
+ bool hasArchAcceleratedFeatures(unsigned fullSmVersion) const {
+ return (fullSmVersion % 10) == 3;
}
};
@@ -106,75 +90,17 @@ namespace mlir {
namespace OpTrait {
-template <int MinVersion>
+template <unsigned... FullSMVersions>
class NVVMRequiresSM {
public:
template <typename ConcreteOp>
class Impl
- : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
- public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
- public:
- NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(false, false, MinVersion);
- }
- };
-};
-
-// SMVersions is a template parameter pack of the supported
-// architecture-accelerated SM versions.
-template <int... SMVersions>
-class NVVMRequiresSMAA {
-public:
- template <typename ConcreteOp>
- class Impl : public OpTrait::TraitBase<ConcreteOp,
- NVVMRequiresSMAA<SMVersions...>::Impl>,
- public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
- public:
- NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
- }
- };
-
- static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
- return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
- }
-};
-
-// SMVersions is a template parameter pack of the supported family-specific SM
-// versions.
-template <int... SMVersions>
-class NVVMRequiresSMFS {
-public:
- template <typename ConcreteOp>
- class Impl : public OpTrait::TraitBase<ConcreteOp,
- NVVMRequiresSMFS<SMVersions...>::Impl>,
- public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
- public:
- NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
- }
- };
-
- static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
- return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
- }
-};
-
-// SMVersionsAA (SMVersionsFS) is a template parameter pack of the supported
-// architecture-accelerated (family-specific) SM versions.
-template <typename SMVersionsAA, typename SMVersionsFS>
-class NVVMRequiresSMAAOrFS {
-public:
- template <typename ConcreteOp>
- class Impl
- : public OpTrait::TraitBase<
- ConcreteOp, NVVMRequiresSMAAOrFS<SMVersionsAA, SMVersionsFS>::Impl>,
+ : public OpTrait::TraitBase<ConcreteOp,
+ NVVMRequiresSM<FullSMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- auto result = SMVersionsAA::getRequiredMinSMVersion();
- result.append(SMVersionsFS::getRequiredMinSMVersion());
- return result;
+ return NVVM::NVVMCheckSMVersion(FullSMVersions...);
}
};
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 0b009819fc2c3..0e06085803b8d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -21,7 +21,7 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
let cppNamespace = "::mlir::NVVM";
let methods = [
InterfaceMethod<
- "Get the SM version required by the op from the trait",
+          "Get the SM version required by the op from the trait",
"mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
>
];
@@ -30,19 +30,25 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
// Op requires a specified minimum SM value or higher;
// it is not architecture-specific.
class NVVMRequiresSM<int minVersion> :
- ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
+ ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion) # "0">;
// Op requires an exact SM match along with architecture acceleration.
-class NVVMRequiresSMAA<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMAA", !interleave(smVersions, ",")>;
+class NVVMRequiresSMa<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSM",
+ !interleave(!foreach(smVersion, smVersions,
+ !add(!mul(smVersion, 10), 3)), ",")>;
// Op requires an SM version belonging to the family.
-class NVVMRequiresSMFS<list<int> smVersions> :
- ParamNativeOpTrait<"NVVMRequiresSMFS", !interleave(smVersions, ",")>;
+class NVVMRequiresSMf<list<int> smVersions> :
+ ParamNativeOpTrait<"NVVMRequiresSM",
+ !interleave(!foreach(smVersion, smVersions,
+ !add(!mul(smVersion, 10), 2)), ",")>;
// Op supported on some combination of architecture acceleration and family-specific SM versions.
-class NVVMRequiresSMAAOrFS<list<int> smVersionsAA, list<int> smVersionsFS> :
- ParamNativeOpTrait<"NVVMRequiresSMAAOrFS",
- "mlir::OpTrait::NVVMRequiresSMAA<" # !interleave(smVersionsAA, ",") # ">" # "," #
- "mlir::OpTrait::NVVMRequiresSMFS<" # !interleave(smVersionsFS, ",") # ">">;
+class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
+ ParamNativeOpTrait<"NVVMRequiresSM",
+ !interleave(!foreach(smVersion, smVersionsA,
+ !add(!mul(smVersion, 10), 3)), ",") # "," #
+ !interleave(!foreach(smVersion, smVersionsF,
+ !add(!mul(smVersion, 10), 2)), ",")>;
#endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ccd59cec65bc..bf3931075e08b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5822,9 +5822,9 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
"NVVM target attribute must be attached to a GPU module");
}
- const NVVMCheckSMVersion targetSMVersion =
- NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
- if (!targetSMVersion.isMinimumSMVersion()) {
+ const unsigned targetFullSmVersion =
+ NVVMCheckSMVersion::getTargetFullSmVersionFromStr(getChip());
+ if (!NVVMCheckSMVersion::isMinimumSMVersion(targetFullSmVersion)) {
return emitError(gpuModule->getLoc(),
"Minimum NVVM target SM version is sm_20");
}
@@ -5834,7 +5834,7 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
const NVVMCheckSMVersion requirement =
reqOp.getRequiredMinSMVersion();
- if (!requirement.isCompatibleWith(targetSMVersion)) {
+ if (!requirement.isCompatibleWith(targetFullSmVersion)) {
op->emitOpError() << "is not supported on " << getChip();
return WalkResult::interrupt();
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index 7876f57a4d9ce..43522a6ae14c1 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -1,53 +1,53 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// Just check these don't emit errors.
-gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
- test.nvvm_requires_sm_80
-}
+// gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+// test.nvvm_requires_sm_80
+// }
-gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
- test.nvvm_requires_sm_80
-}
+// gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+// test.nvvm_requires_sm_80
+// }
-gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
- test.nvvm_requires_sm_80
-}
+// gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+// test.nvvm_requires_sm_80
+// }
-gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
- test.nvvm_requires_sm_90a
-}
+// gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
+// test.nvvm_requires_sm_90a
+// }
-gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
- test.nvvm_requires_sm_80
-}
+// gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
+// test.nvvm_requires_sm_80
+// }
-gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
- test.nvvm_requires_sm_90a_or_sm_100a
-}
+// gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
+// test.nvvm_requires_sm_90a_or_sm_100a
+// }
-gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
- test.nvvm_requires_sm_90a_or_sm_100a
-}
+// gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
+// test.nvvm_requires_sm_90a_or_sm_100a
+// }
-gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
- test.nvvm_requires_sm_100f
-}
+// gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
+// test.nvvm_requires_sm_100f
+// }
-gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
- test.nvvm_requires_sm_100f
-}
+// gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
+// test.nvvm_requires_sm_100f
+// }
-gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
- test.nvvm_requires_sm_100f
-}
+// gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
+// test.nvvm_requires_sm_100f
+// }
-gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
- test.nvvm_requires_sm_100f_or_sm_120f
-}
+// gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
+// test.nvvm_requires_sm_100f_or_sm_120f
+// }
-gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
- test.nvvm_requires_sm_100f_or_sm_120f
-}
+// gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
+// test.nvvm_requires_sm_100f_or_sm_120f
+// }
gpu.module @check_valid_SM_arch_or_family_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a_or_sm_100f
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 38df7d5b13e50..49fedc643c62a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3084,31 +3084,31 @@ def TestNVVMRequiresSMOp :
}
def TestNVVMRequiresSMArchCondOp :
- TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMAA<[90]>]> {
+ TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequirestSMArchCondMultiOp :
- TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMAA<[90, 100]>]> {
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMFamilyCondOp :
- TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMFS<[100]>]> {
+ TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMFamilyCondMultiOp :
- TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMFS<[100, 120]>]> {
+ TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
def TestNVVMRequiresSMArchOrFamilyCondOp :
- TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMAAOrFS<[90], [100]>]> {
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
>From 3b295e4dfa70f8e8851d1dc555a29e0810bcba3c Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 25 Mar 2026 10:42:52 +0000
Subject: [PATCH 07/10] rename to NVVMRequiresSMaOrSMf
---
mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td | 2 +-
mlir/test/lib/Dialect/Test/TestOps.td | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 0e06085803b8d..e559eaa7e89dc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -45,7 +45,7 @@ class NVVMRequiresSMf<list<int> smVersions> :
!add(!mul(smVersion, 10), 2)), ",")>;
// Op supported on some combination of architecture acceleration and family-specific SM versions.
-class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
+class NVVMRequiresSMaOrSMf<list<int> smVersionsA, list<int> smVersionsF> :
ParamNativeOpTrait<"NVVMRequiresSM",
!interleave(!foreach(smVersion, smVersionsA,
!add(!mul(smVersion, 10), 3)), ",") # "," #
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 49fedc643c62a..abd325b5c3a80 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3108,7 +3108,7 @@ def TestNVVMRequiresSMFamilyCondMultiOp :
}
def TestNVVMRequiresSMArchOrFamilyCondOp :
- TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+ TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrSMf<[90], [100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
>From e783902612fd5e3d4a9f0ef918a04783269708f5 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 25 Mar 2026 14:30:25 +0000
Subject: [PATCH 08/10] uncomment test cases
---
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 72 +++++++++----------
1 file changed, 36 insertions(+), 36 deletions(-)
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index 43522a6ae14c1..7876f57a4d9ce 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -1,53 +1,53 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// Just check these don't emit errors.
-// gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
-// test.nvvm_requires_sm_80
-// }
+gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+ test.nvvm_requires_sm_80
+}
-// gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
-// test.nvvm_requires_sm_80
-// }
+gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+ test.nvvm_requires_sm_80
+}
-// gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
-// test.nvvm_requires_sm_80
-// }
+gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+ test.nvvm_requires_sm_80
+}
-// gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
-// test.nvvm_requires_sm_90a
-// }
+gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a
+}
-// gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
-// test.nvvm_requires_sm_80
-// }
+gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_80
+}
-// gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
-// test.nvvm_requires_sm_90a_or_sm_100a
-// }
+gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a_or_sm_100a
+}
-// gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
-// test.nvvm_requires_sm_90a_or_sm_100a
-// }
+gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
+ test.nvvm_requires_sm_90a_or_sm_100a
+}
-// gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
-// test.nvvm_requires_sm_100f
-// }
+gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f
+}
-// gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
-// test.nvvm_requires_sm_100f
-// }
+gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
+ test.nvvm_requires_sm_100f
+}
-// gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
-// test.nvvm_requires_sm_100f
-// }
+gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
+ test.nvvm_requires_sm_100f
+}
-// gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
-// test.nvvm_requires_sm_100f_or_sm_120f
-// }
+gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
-// gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
-// test.nvvm_requires_sm_100f_or_sm_120f
-// }
+gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
+ test.nvvm_requires_sm_100f_or_sm_120f
+}
gpu.module @check_valid_SM_arch_or_family_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a_or_sm_100f
>From c4d6c209d21d3bc3afb43e9411f1bc9801260f02 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 31 Mar 2026 06:37:02 +0000
Subject: [PATCH 09/10] address comments
---
.../mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h | 14 ++++++++++----
.../mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td | 4 ++--
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir | 11 +++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 6 ++++++
5 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 82f41d88c1037..09a641b8d3220 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -46,15 +46,21 @@ struct NVVMCheckSMVersion {
if (hasFamilySpecificFeatures(requiredFullSmVersion))
return hasFamilySpecificFeatures(targetFullSmVersion) &&
- ((targetFullSmVersion / 100) ==
- (requiredFullSmVersion / 100));
+ (getSMFamily(targetFullSmVersion) ==
+ getSMFamily(requiredFullSmVersion)) &&
+ (getSMVersion(targetFullSmVersion) >=
+ getSMVersion(requiredFullSmVersion));
return targetFullSmVersion >= requiredFullSmVersion;
});
}
- static bool isMinimumSMVersion(unsigned targetFullSmVersion) {
- return targetFullSmVersion >= 200;
+ static unsigned getSMVersion(unsigned fullSmVersion) {
+ return fullSmVersion / 10;
+ }
+
+ static unsigned getSMFamily(unsigned fullSmVersion) {
+ return fullSmVersion / 100;
}
// Parses an SM version string and returns an equivalent full SM version
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index e559eaa7e89dc..2657f482d0212 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -48,7 +48,7 @@ class NVVMRequiresSMf<list<int> smVersions> :
class NVVMRequiresSMaOrSMf<list<int> smVersionsA, list<int> smVersionsF> :
ParamNativeOpTrait<"NVVMRequiresSM",
!interleave(!foreach(smVersion, smVersionsA,
- !add(!mul(smVersion, 10), 3)), ",") # "," #
- !interleave(!foreach(smVersion, smVersionsF,
+ !add(!mul(smVersion, 10), 3)) #
+ !foreach(smVersion, smVersionsF,
!add(!mul(smVersion, 10), 2)), ",")>;
#endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index bf3931075e08b..669f1e8d62f25 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5824,7 +5824,7 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
const unsigned targetFullSmVersion =
NVVMCheckSMVersion::getTargetFullSmVersionFromStr(getChip());
- if (!NVVMCheckSMVersion::isMinimumSMVersion(targetFullSmVersion)) {
+ if (NVVMCheckSMVersion::getSMVersion(targetFullSmVersion) < 20) {
return emitError(gpuModule->getLoc(),
"Minimum NVVM target SM version is sm_20");
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index 7876f57a4d9ce..686b9671fb4ee 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -41,6 +41,10 @@ gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
test.nvvm_requires_sm_100f
}
+gpu.module @check_valid_SM_family_4[#nvvm.target<chip = "sm_103f">] {
+ test.nvvm_requires_sm_100f
+}
+
gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
test.nvvm_requires_sm_100f_or_sm_120f
}
@@ -136,6 +140,13 @@ gpu.module @check_invalid_SM_family [#nvvm.target<chip = "sm_110a">] {
// -----
+gpu.module @check_invalid_SM_family_higher [#nvvm.target<chip = "sm_100f">] {
+ // expected-error @below {{is not supported on sm_100f}}
+ test.nvvm_requires_sm_103f
+}
+
+// -----
+
gpu.module @check_invalid_SM_family_multi [#nvvm.target<chip = "sm_110a">] {
// expected-error @below {{is not supported on sm_110a}}
test.nvvm_requires_sm_100f_or_sm_120f
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index abd325b5c3a80..7400eb7ff33f7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3101,6 +3101,12 @@ def TestNVVMRequiresSMFamilyCondOp :
let assemblyFormat = "attr-dict";
}
+def TestNVVMRequiresSMFamilyCondHigherOp :
+ TEST_Op<"nvvm_requires_sm_103f", [NVVMRequiresSMf<[103]>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
def TestNVVMRequiresSMFamilyCondMultiOp :
TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
let arguments = (ins );
>From da83ff2bbc5df93f5e1c1fc34f799cac3149154e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 1 Apr 2026 14:55:05 +0000
Subject: [PATCH 10/10] address review comments: make SM version helpers static and private
---
.../Dialect/LLVMIR/NVVMRequiresSMTraits.h | 30 +++++++++++--------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 09a641b8d3220..e67cf999a06e7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -42,7 +42,9 @@ struct NVVMCheckSMVersion {
return llvm::any_of(
fullSmVersionList, [&](const unsigned &requiredFullSmVersion) {
if (hasArchAcceleratedFeatures(requiredFullSmVersion))
- return targetFullSmVersion == requiredFullSmVersion;
+ return hasArchAcceleratedFeatures(targetFullSmVersion) &&
+ (getSMVersion(targetFullSmVersion) ==
+ getSMVersion(requiredFullSmVersion));
if (hasFamilySpecificFeatures(requiredFullSmVersion))
return hasFamilySpecificFeatures(targetFullSmVersion) &&
@@ -55,14 +57,6 @@ struct NVVMCheckSMVersion {
});
}
- static unsigned getSMVersion(unsigned fullSmVersion) {
- return fullSmVersion / 10;
- }
-
- static unsigned getSMFamily(unsigned fullSmVersion) {
- return fullSmVersion / 100;
- }
-
// Parses an SM version string and returns an equivalent full SM version
// integer.
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString) {
@@ -76,15 +70,27 @@ struct NVVMCheckSMVersion {
return smVersion * 10 + (isAA ? 3 : 0) + (isFS ? 2 : 0);
}
+
+ static bool isMinimumSMVersion(unsigned fullSmVersion) {
+ return getSMVersion(fullSmVersion) >= 20;
+ }
private:
- bool hasFamilySpecificFeatures(unsigned fullSmVersion) const {
- return (fullSmVersion % 10) != 0;
+ static bool hasFamilySpecificFeatures(unsigned fullSmVersion) {
+ return (fullSmVersion % 10) >= 2;
}
- bool hasArchAcceleratedFeatures(unsigned fullSmVersion) const {
+ static bool hasArchAcceleratedFeatures(unsigned fullSmVersion) {
return (fullSmVersion % 10) == 3;
}
+
+ static unsigned getSMVersion(unsigned fullSmVersion) {
+ return fullSmVersion / 10;
+ }
+
+ static unsigned getSMFamily(unsigned fullSmVersion) {
+ return fullSmVersion / 100;
+ }
};
} // namespace NVVM
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 669f1e8d62f25..bf3931075e08b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5824,7 +5824,7 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
const unsigned targetFullSmVersion =
NVVMCheckSMVersion::getTargetFullSmVersionFromStr(getChip());
- if (NVVMCheckSMVersion::getSMVersion(targetFullSmVersion) < 20) {
+ if (!NVVMCheckSMVersion::isMinimumSMVersion(targetFullSmVersion)) {
return emitError(gpuModule->getLoc(),
"Minimum NVVM target SM version is sm_20");
}
More information about the Mlir-commits
mailing list