[Mlir-commits] [mlir] [MLIR][NVVM] Add family-conditional support to NVVMRequiresSM traits (PR #185909)

Srinivasa Ravi llvmlistbot at llvm.org
Wed Mar 18 01:07:08 PDT 2026


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/185909

>From f3c48e3cdebb4530ec3b403de09cc1eaac06f741 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 10 Mar 2026 10:14:39 +0000
Subject: [PATCH 1/5] [MLIR][NVVM] Add family-conditional support to
 NVVMRequiresSM traits

This change adds support for family-conditional SM version requirements
to the `NVVMRequiresSM` traits. The following new traits are added:
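- `NVVMRequiresSMf` - Op requires a family-specific (`f`) or
  arch-accelerated (`a`) target whose SM version belongs to one of the
  given families and is not lower than the listed version of that family
  (e.g., a requirement of sm_100f is satisfied by sm_100f, sm_100a, and
  sm_103a, but not by sm_100 or sm_110a).
- `NVVMRequiresSMaOrf` - Op requires either an arch-accelerated target that
  exactly matches one of the listed arch-accelerated versions, or a target
  that satisfies the family-specific requirement described above for one of
  the listed families.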
---
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.h     | 117 +++++++++++++-----
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.td    |  20 ++-
 .../Dialect/LLVMIR/nvvm-check-targetSM.mlir   |  64 ++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |  18 +++
 4 files changed, 187 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 36fcaee8ec3a2..87c488eda7b62 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -23,55 +23,79 @@ namespace NVVM {
 
 // Struct to store and check compatibility of SM versions.
 struct NVVMCheckSMVersion {
-  // Set to true if the SM version is accelerated (e.g., sm_90a).
-  bool archAccelerated;
+  struct SMVersion {
+    unsigned version;
+    // Set to true if the SM version is accelerated (e.g., sm_90a).
+    bool archAccelerated;
+    // Set to true if the SM version is family-specific (e.g., sm_100f).
+    bool familySpecific;
+
+    unsigned getSmFamilyVersion() const {
+      return version / 10;
+    }
+    
+    bool hasFamilySpecificFeatures() const {
+      return familySpecific || archAccelerated;
+    }
+  };
 
   // List of SM versions.
   // Typically only has one version except for cases where multiple
   // arch-accelerated versions are supported.
   // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
-  llvm::SmallVector<int, 1> smVersionList;
-
-  template <typename... Ints>
-  NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
-      : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
-    assert((archAccelerated || smVersionList.size() == 1) &&
-           "non arch-accelerated SM version list must be a single version!");
+  llvm::SmallVector<SMVersion, 1> smVersionList;
+
+  template <typename... Versions>
+  NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
+                     Versions... smVersions)
+      : smVersionList({SMVersion{static_cast<unsigned>(smVersions),
+                                 archAccelerated, familySpecific}...}) {
+    assert(
+        !(archAccelerated && familySpecific) &&
+        "archAccelerated and familySpecific cannot be true at the same time!");
   }
 
   bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
     assert(targetSM.smVersionList.size() == 1 &&
            "target SM version list must be a single version!");
 
-    if (archAccelerated) {
-      if (!targetSM.archAccelerated)
-        return false;
-
-      for (auto version : smVersionList) {
-        if (version == targetSM.smVersionList[0])
-          return true;
+    SMVersion targetSMVersion = targetSM.smVersionList[0];
+
+    return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
+      if (RequiredSMVersion.archAccelerated) {
+        return targetSMVersion.archAccelerated &&
+               (RequiredSMVersion.version == targetSMVersion.version);
+      } else if (RequiredSMVersion.familySpecific) {
+        return targetSMVersion.hasFamilySpecificFeatures() &&
+               (RequiredSMVersion.getSmFamilyVersion() ==
+                targetSMVersion.getSmFamilyVersion()) &&
+               (targetSMVersion.version >= RequiredSMVersion.version);
+      } else {
+        return targetSMVersion.version >= RequiredSMVersion.version;
       }
-    } else {
-      return targetSM.smVersionList[0] >= smVersionList[0];
-    }
-
-    return false;
+    });
   }
 
-  bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
+  bool isMinimumSMVersion() const { return smVersionList[0].version >= 20; }
 
   // Parses an SM version string and returns an equivalent NVVMCheckSMVersion
   // object.
-  static const NVVMCheckSMVersion
+  static NVVMCheckSMVersion
   getTargetSMVersionFromStr(StringRef smVersionString) {
     bool isAA = smVersionString.back() == 'a';
+    bool isFS = smVersionString.back() == 'f';
 
     int smVersionInt;
     smVersionString.drop_front(3)
         .take_while([](char c) { return llvm::isDigit(c); })
         .getAsInteger(10, smVersionInt);
 
-    return NVVMCheckSMVersion(isAA, smVersionInt);
+    return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
+  }
+  
+  NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+    smVersionList.append(other.smVersionList);
+    return *this;
   }
 };
 
@@ -92,8 +116,8 @@ class NVVMRequiresSM {
       : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
         public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
-    const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
-      return NVVM::NVVMCheckSMVersion(false, MinVersion);
+    NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+      return NVVM::NVVMCheckSMVersion(false, false, MinVersion);
     }
   };
 };
@@ -106,12 +130,49 @@ class NVVMRequiresSMa {
                                          NVVMRequiresSMa<SMVersions...>::Impl>,
                public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
-    const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
-      return NVVM::NVVMCheckSMVersion(true, SMVersions...);
+    NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+      return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
+    }
+  };
+
+  static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+    return NVVM::NVVMCheckSMVersion(true, false, SMVersions...);
+  }
+};
+
+template <int... SMVersions>
+class NVVMRequiresSMf {
+public:
+  template <typename ConcreteOp>
+  class Impl : public OpTrait::TraitBase<ConcreteOp,
+                                         NVVMRequiresSMf<SMVersions...>::Impl>,
+               public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+  public:
+    NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+      return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
     }
   };
+
+  static NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() {
+    return NVVM::NVVMCheckSMVersion(false, true, SMVersions...);
+  }
 };
 
+template <typename T, typename U>
+class NVVMRequiresSMaOrf {
+public:
+  template <typename ConcreteOp>
+  class Impl
+      : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSMaOrf<T, U>::Impl>,
+        public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+  public:
+    NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+      auto result = T::getRequiredMinSMVersion();
+      result.append(U::getRequiredMinSMVersion());
+      return result;
+    }
+  };
+};
 } // namespace OpTrait
 } // namespace mlir
 #endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 34c0d6b78d5b2..8b35cc48bcfdb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -22,12 +22,12 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
   let methods = [
     InterfaceMethod<
       "Get the SM version required by the op from the trait", 
-      "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
+      "mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
     >
   ];
 }
 
-// OP requires a specified minimum SM value or higher; 
+// Op requires a specified minimum SM value or higher; 
 // it is not architecture-specific.
 class NVVMRequiresSM<int minVersion> :
   ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
@@ -37,11 +37,23 @@ class StrJoin<string sep, list<string> str_list> {
                !if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
 }
 
-// OP requires an exact SM match along with 
-// architecture acceleration.
+// Op requires an exact SM match along with architecture acceleration.
 class NVVMRequiresSMa<list<int> smVersions> :
   ParamNativeOpTrait<"NVVMRequiresSMa",
                       StrJoin<",", !foreach(vers, smVersions,
                                             !cast<string>(vers))>.ret>;
 
+// Op requires an SM version belonging to the family.
+class NVVMRequiresSMf<list<int> smVersions> :
+  ParamNativeOpTrait<"NVVMRequiresSMf",
+                      StrJoin<",", !foreach(vers, smVersions,
+                                            !cast<string>(vers))>.ret>;
+
+// Op supported on some combination of architecture acceleration and family-specific SM versions.
+class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
+  ParamNativeOpTrait<"NVVMRequiresSMaOrf",
+    "mlir::OpTrait::NVVMRequiresSMa<" # StrJoin<",",
+      !foreach(vers, smVersionsA, !cast<string>(vers))>.ret # ">" # "," #
+    "mlir::OpTrait::NVVMRequiresSMf<" # StrJoin<",",
+      !foreach(vers, smVersionsF, !cast<string>(vers))>.ret # ">">;
 #endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index e469d336dc1ae..7876f57a4d9ce 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -29,6 +29,37 @@ gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
   test.nvvm_requires_sm_90a_or_sm_100a
 }
 
+gpu.module @check_valid_SM_family_1 [#nvvm.target<chip = "sm_100f">] {
+  test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_2 [#nvvm.target<chip = "sm_100a">] {
+  test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_3 [#nvvm.target<chip = "sm_103a">] {
+  test.nvvm_requires_sm_100f
+}
+
+gpu.module @check_valid_SM_family_multi_1 [#nvvm.target<chip = "sm_100f">] {
+  test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_family_multi_2 [#nvvm.target<chip = "sm_120f">] {
+  test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @check_valid_SM_arch_or_family_1 [#nvvm.target<chip = "sm_90a">] {
+  test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_2 [#nvvm.target<chip = "sm_100f">] {
+  test.nvvm_requires_sm_90a_or_sm_100f
+}
+
+gpu.module @check_valid_SM_arch_or_family_3 [#nvvm.target<chip = "sm_103a">] {
+  test.nvvm_requires_sm_90a_or_sm_100f
+}
 
 gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
   test.nvvm_requires_sm_90a
@@ -42,6 +73,18 @@ gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget =
   test.nvvm_requires_sm_90a_or_sm_100a
 }
 
+gpu.module @disable_verify_target4 [#nvvm.target<chip = "sm_120f", verifyTarget = false>] {
+  test.nvvm_requires_sm_100f
+}
+
+gpu.module @disable_verify_target5 [#nvvm.target<chip = "sm_100", verifyTarget = false>] {
+  test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+gpu.module @disable_verify_target6 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+  test.nvvm_requires_sm_90a_or_sm_100f
+}
+
 // -----
 
 gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
@@ -83,3 +126,24 @@ gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_90">] {
   // expected-error @below {{is not supported on sm_90}}
   test.nvvm_requires_sm_90a_or_sm_100a
 }
+
+// -----
+
+gpu.module @check_invalid_SM_family [#nvvm.target<chip = "sm_110a">] {
+  // expected-error @below {{is not supported on sm_110a}}
+  test.nvvm_requires_sm_100f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_family_multi [#nvvm.target<chip = "sm_110a">] {
+  // expected-error @below {{is not supported on sm_110a}}
+  test.nvvm_requires_sm_100f_or_sm_120f
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_or_family [#nvvm.target<chip = "sm_100">] {
+  // expected-error @below {{is not supported on sm_100}}
+  test.nvvm_requires_sm_90a_or_sm_100f
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 02bac016eeed1..49fedc643c62a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3095,6 +3095,24 @@ def TestNVVMRequirestSMArchCondMultiOp :
   let assemblyFormat = "attr-dict";
 }
 
+def TestNVVMRequiresSMFamilyCondOp :
+    TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMFamilyCondMultiOp :
+    TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchOrFamilyCondOp :
+    TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//

>From 3cf3e1513808fe6439c0241abb84d3d67c7554cc Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 12 Mar 2026 08:03:24 +0000
Subject: [PATCH 2/5] fix formatting

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index 87c488eda7b62..dbe4dfd17e84d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -30,10 +30,8 @@ struct NVVMCheckSMVersion {
     // Set to true if the SM version is family-specific (e.g., sm_100f).
     bool familySpecific;
 
-    unsigned getSmFamilyVersion() const {
-      return version / 10;
-    }
-    
+    unsigned getSmFamilyVersion() const { return version / 10; }
+
     bool hasFamilySpecificFeatures() const {
       return familySpecific || archAccelerated;
     }
@@ -92,7 +90,7 @@ struct NVVMCheckSMVersion {
 
     return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
   }
-  
+
   NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
     smVersionList.append(other.smVersionList);
     return *this;

>From b0dc8e4a9b003560409395e7bf64a9818dc2225c Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 12 Mar 2026 12:04:58 +0000
Subject: [PATCH 3/5] replace StrJoin with interleave

---
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.td    | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 8b35cc48bcfdb..8c3d34fe033e9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -32,28 +32,17 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
 class NVVMRequiresSM<int minVersion> :
   ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
 
-class StrJoin<string sep, list<string> str_list> {
-  string ret = !foldl("", str_list, a, b,
-               !if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
-}
-
 // Op requires an exact SM match along with architecture acceleration.
 class NVVMRequiresSMa<list<int> smVersions> :
-  ParamNativeOpTrait<"NVVMRequiresSMa",
-                      StrJoin<",", !foreach(vers, smVersions,
-                                            !cast<string>(vers))>.ret>;
+  ParamNativeOpTrait<"NVVMRequiresSMa", !interleave(smVersions, ",")>;
 
 // Op requires an SM version belonging to the family.
 class NVVMRequiresSMf<list<int> smVersions> :
-  ParamNativeOpTrait<"NVVMRequiresSMf",
-                      StrJoin<",", !foreach(vers, smVersions,
-                                            !cast<string>(vers))>.ret>;
+  ParamNativeOpTrait<"NVVMRequiresSMf", !interleave(smVersions, ",")>;
 
 // Op supported on some combination of architecture acceleration and family-specific SM versions.
 class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
   ParamNativeOpTrait<"NVVMRequiresSMaOrf",
-    "mlir::OpTrait::NVVMRequiresSMa<" # StrJoin<",",
-      !foreach(vers, smVersionsA, !cast<string>(vers))>.ret # ">" # "," #
-    "mlir::OpTrait::NVVMRequiresSMf<" # StrJoin<",",
-      !foreach(vers, smVersionsF, !cast<string>(vers))>.ret # ">">;
+    "mlir::OpTrait::NVVMRequiresSMa<" # !interleave(smVersionsA, ",") # ">" # "," #
+    "mlir::OpTrait::NVVMRequiresSMf<" # !interleave(smVersionsF, ",") # ">">;
 #endif //NVVM_REQUIRES_SM_TRAITS

>From 3020c5af597710d06a2da3d9f0b475d098deb6bb Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 16 Mar 2026 06:14:17 +0000
Subject: [PATCH 4/5] address comments

---
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.h     | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index dbe4dfd17e84d..f6718ad0e29f6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -24,8 +24,9 @@ namespace NVVM {
 // Struct to store and check compatibility of SM versions.
 struct NVVMCheckSMVersion {
   struct SMVersion {
+    // Base SM version (e.g., sm_100)
     unsigned version;
-    // Set to true if the SM version is accelerated (e.g., sm_90a).
+    // Set to true if the SM version is accelerated (e.g., sm_100a).
     bool archAccelerated;
     // Set to true if the SM version is family-specific (e.g., sm_100f).
     bool familySpecific;
@@ -39,9 +40,9 @@ struct NVVMCheckSMVersion {
 
   // List of SM versions.
   // Typically only has one version except for cases where multiple
-  // arch-accelerated versions are supported.
+  // arch-accelerated or family-conditional versions are supported.
   // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
-  llvm::SmallVector<SMVersion, 1> smVersionList;
+  llvm::SmallVector<SMVersion> smVersionList;
 
   template <typename... Versions>
   NVVMCheckSMVersion(bool archAccelerated, bool familySpecific,
@@ -55,22 +56,22 @@ struct NVVMCheckSMVersion {
 
   bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
     assert(targetSM.smVersionList.size() == 1 &&
-           "target SM version list must be a single version!");
+           "target SM version list must have a single version!");
 
     SMVersion targetSMVersion = targetSM.smVersionList[0];
 
     return llvm::any_of(smVersionList, [&](const SMVersion &RequiredSMVersion) {
-      if (RequiredSMVersion.archAccelerated) {
+      if (RequiredSMVersion.archAccelerated)
         return targetSMVersion.archAccelerated &&
                (RequiredSMVersion.version == targetSMVersion.version);
-      } else if (RequiredSMVersion.familySpecific) {
+
+      if (RequiredSMVersion.familySpecific)
         return targetSMVersion.hasFamilySpecificFeatures() &&
                (RequiredSMVersion.getSmFamilyVersion() ==
                 targetSMVersion.getSmFamilyVersion()) &&
                (targetSMVersion.version >= RequiredSMVersion.version);
-      } else {
-        return targetSMVersion.version >= RequiredSMVersion.version;
-      }
+
+      return targetSMVersion.version >= RequiredSMVersion.version;
     });
   }
 
@@ -91,9 +92,8 @@ struct NVVMCheckSMVersion {
     return NVVMCheckSMVersion(isAA, isFS, smVersionInt);
   }
 
-  NVVMCheckSMVersion &append(const NVVMCheckSMVersion &other) {
+  void append(const NVVMCheckSMVersion &other) {
     smVersionList.append(other.smVersionList);
-    return *this;
   }
 };
 
@@ -156,17 +156,18 @@ class NVVMRequiresSMf {
   }
 };
 
-template <typename T, typename U>
+template <typename SMVersionsA, typename SMVersionsF>
 class NVVMRequiresSMaOrf {
 public:
   template <typename ConcreteOp>
   class Impl
-      : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSMaOrf<T, U>::Impl>,
+      : public OpTrait::TraitBase<
+            ConcreteOp, NVVMRequiresSMaOrf<SMVersionsA, SMVersionsF>::Impl>,
         public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
     NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
-      auto result = T::getRequiredMinSMVersion();
-      result.append(U::getRequiredMinSMVersion());
+      auto result = SMVersionsA::getRequiredMinSMVersion();
+      result.append(SMVersionsF::getRequiredMinSMVersion());
       return result;
     }
   };

>From 6535e13f5ce6e3aa2029a2934fae5157de163508 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 18 Mar 2026 08:06:30 +0000
Subject: [PATCH 5/5] address comments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 42 +++++++++----------
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.h     | 24 +++++++----
 .../Dialect/LLVMIR/NVVMRequiresSMTraits.td    | 16 +++----
 mlir/test/lib/Dialect/Test/TestOps.td         | 10 ++---
 4 files changed, 49 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e1ab38e80d4..33442189ff76f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2247,7 +2247,7 @@ def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "conve
 // (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
 // These operations always use RS (stochastic rounding) mode with SATFINITE saturation.
 class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type dstType> :
-  NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+  NVVM_Op<mnemonic, [Pure, NVVMRequiresSMAA<[100, 103]>]>,
   Results<(outs dstType:$dst)>,
   Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
                  DefaultValuedAttr<BoolAttr, "false">:$relu,
@@ -4514,7 +4514,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMAA<[90]>]> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -4528,7 +4528,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
   let assemblyFormat = "attr-dict";
   let description = [{
     Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -4540,7 +4540,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [N
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMAA<[90]>]> {
   let arguments = (ins I64Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
@@ -4902,7 +4902,7 @@ def Tcgen05WaitKindAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 alloc operation";
   let description = [{
     The `tcgen05.alloc` Op allocates tensor core memory for
@@ -4932,7 +4932,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]
   }];
 }
 
-def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 dealloc operation";
   let description = [{
     The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -4960,7 +4960,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10
   }];
 }
 
-def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 Op to relinquish the right to allocate";
   let description = [{
     The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -4983,7 +4983,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
   }];
 }
 
-def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 fence operations";
   let description = [{
     The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -5005,7 +5005,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]
   }];
 }
 
-def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 wait operations";
   let description = [{
     The `tcgen05.wait<load>` causes the executing thread to block until
@@ -5027,7 +5027,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]>
   }];
 }
 
-def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 commit operations";
   let description = [{
     The `tcgen05.commit` makes the *mbarrier object*, specified by
@@ -5065,7 +5065,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]
   }];
 }
 
-def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMAA<[100, 101, 103]>]> {
   let summary = "Tcgen05 shift operation";
   let description = [{
     The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -5131,7 +5131,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "Tcgen05 copy operation";
   let description = [{
     Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -5267,7 +5267,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
 // NVVM tcgen05.ld Op
 //===----------------------------------------------------------------------===//
 
-def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "tensor memory load instructions";
   let arguments = (ins
     // Attributes
@@ -5360,7 +5360,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
 //===----------------------------------------------------------------------===//
 
 def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
-                          [NVVMRequiresSMa<[101, 110]>]> {
+                          [NVVMRequiresSMAA<[101, 110]>]> {
   let summary = "Tcgen05 tensor memory load and reduce instructions";
   let arguments = (ins
     Tcgen05LdStShapeAttr:$shape,
@@ -5449,7 +5449,7 @@ def NVVM_Tcgen05LdRedOp : NVVM_Op<"tcgen05.ld.red",
 // NVVM tcgen05.st Op
 //===----------------------------------------------------------------------===//
 
-def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
+def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMAA<[100, 101]>]> {
   let summary = "tensor memory store instructions";
   let arguments = (ins
     // Attributes
@@ -5852,7 +5852,7 @@ defvar Tcgen05MMABlockScaleKindAttr =
 
 def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
                           [AttrSizedOperandSegments,
-                           NVVMRequiresSMa<[100, 110]>]> {
+                           NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs MMA operation on 5th-gen tensor cores";
 
   let description = [{
@@ -5936,7 +5936,7 @@ def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
 
 def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
                                       [AttrSizedOperandSegments,
-                                       NVVMRequiresSMa<[100, 110]>]> {
+                                       NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs MMA operation with sparse A matrix on 5th-gen tensor cores";
 
   let description = [{
@@ -6017,7 +6017,7 @@ def Tcgen05MMABlockScaleAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScale,
 }
 
 def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
-                                          [NVVMRequiresSMa<[100, 110]>]> {
+                                          [NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs block scaled MMA operation on 5th-gen tensor cores";
 
   let description = [{
@@ -6090,7 +6090,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
 }
 
 def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
-                                                [NVVMRequiresSMa<[100, 110]>]> {
+                                                [NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs block scaled MMA operation with sparse A matrix on 5th-gen tensor cores";
 
   let description = [{
@@ -6172,7 +6172,7 @@ def Tcgen05MMACollectorBBufferAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorB
 }
 
 def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
-                                  [NVVMRequiresSMa<[100, 110]>]> {
+                                  [NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs weight stationary convolution MMA operation on 5th-gen tensor cores";
 
   let description = [{
@@ -6242,7 +6242,7 @@ def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
 }
 
 def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp",
-                                        [NVVMRequiresSMa<[100, 110]>]> {
+                                        [NVVMRequiresSMAA<[100, 110]>]> {
   let summary = "Performs weight stationary convolution MMA with sparse A matrix on 5th-gen tensor cores";
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
index f6718ad0e29f6..e0bb3c5be8ab7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
@@ -120,12 +120,14 @@ class NVVMRequiresSM {
   };
 };
 
+// SMVersions is a template parameter pack of the supported
+// architecture-accelerated SM versions.
 template <int... SMVersions>
-class NVVMRequiresSMa {
+class NVVMRequiresSMAA {
 public:
   template <typename ConcreteOp>
   class Impl : public OpTrait::TraitBase<ConcreteOp,
-                                         NVVMRequiresSMa<SMVersions...>::Impl>,
+                                         NVVMRequiresSMAA<SMVersions...>::Impl>,
                public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
     NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
@@ -138,12 +140,14 @@ class NVVMRequiresSMa {
   }
 };
 
+// SMVersions is a template parameter pack of the supported family-specific SM
+// versions.
 template <int... SMVersions>
-class NVVMRequiresSMf {
+class NVVMRequiresSMFS {
 public:
   template <typename ConcreteOp>
   class Impl : public OpTrait::TraitBase<ConcreteOp,
-                                         NVVMRequiresSMf<SMVersions...>::Impl>,
+                                         NVVMRequiresSMFS<SMVersions...>::Impl>,
                public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
     NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
@@ -156,18 +160,20 @@ class NVVMRequiresSMf {
   }
 };
 
-template <typename SMVersionsA, typename SMVersionsF>
-class NVVMRequiresSMaOrf {
+// SMVersionsAA (SMVersionsFS) is the NVVMRequiresSMAA (NVVMRequiresSMFS) type
+// holding the supported architecture-accelerated (family-specific) SM versions.
+template <typename SMVersionsAA, typename SMVersionsFS>
+class NVVMRequiresSMAAOrFS {
 public:
   template <typename ConcreteOp>
   class Impl
       : public OpTrait::TraitBase<
-            ConcreteOp, NVVMRequiresSMaOrf<SMVersionsA, SMVersionsF>::Impl>,
+            ConcreteOp, NVVMRequiresSMAAOrFS<SMVersionsAA, SMVersionsFS>::Impl>,
         public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
     NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
-      auto result = SMVersionsA::getRequiredMinSMVersion();
-      result.append(SMVersionsF::getRequiredMinSMVersion());
+      auto result = SMVersionsAA::getRequiredMinSMVersion();
+      result.append(SMVersionsFS::getRequiredMinSMVersion());
       return result;
     }
   };
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
index 8c3d34fe033e9..0b009819fc2c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
@@ -33,16 +33,16 @@ class NVVMRequiresSM<int minVersion> :
   ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
 
 // Op requires an exact SM match along with architecture acceleration.
-class NVVMRequiresSMa<list<int> smVersions> :
-  ParamNativeOpTrait<"NVVMRequiresSMa", !interleave(smVersions, ",")>;
+class NVVMRequiresSMAA<list<int> smVersions> :
+  ParamNativeOpTrait<"NVVMRequiresSMAA", !interleave(smVersions, ",")>;
 
 // Op requires an SM version belonging to the family.
-class NVVMRequiresSMf<list<int> smVersions> :
-  ParamNativeOpTrait<"NVVMRequiresSMf", !interleave(smVersions, ",")>;
+class NVVMRequiresSMFS<list<int> smVersions> :
+  ParamNativeOpTrait<"NVVMRequiresSMFS", !interleave(smVersions, ",")>;
 
 // Op supported on some combination of architecture acceleration and family-specific SM versions.
-class NVVMRequiresSMaOrf<list<int> smVersionsA, list<int> smVersionsF> :
-  ParamNativeOpTrait<"NVVMRequiresSMaOrf",
-    "mlir::OpTrait::NVVMRequiresSMa<" # !interleave(smVersionsA, ",") # ">" # "," #
-    "mlir::OpTrait::NVVMRequiresSMf<" # !interleave(smVersionsF, ",") # ">">;
+class NVVMRequiresSMAAOrFS<list<int> smVersionsAA, list<int> smVersionsFS> :
+  ParamNativeOpTrait<"NVVMRequiresSMAAOrFS",
+    "mlir::OpTrait::NVVMRequiresSMAA<" # !interleave(smVersionsAA, ",") # ">" # "," #
+    "mlir::OpTrait::NVVMRequiresSMFS<" # !interleave(smVersionsFS, ",") # ">">;
 #endif //NVVM_REQUIRES_SM_TRAITS
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 49fedc643c62a..38df7d5b13e50 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3084,31 +3084,31 @@ def TestNVVMRequiresSMOp :
 }
 
 def TestNVVMRequiresSMArchCondOp :
-    TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> {
+    TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMAA<[90]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }
 
 def TestNVVMRequirestSMArchCondMultiOp :
-    TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
+    TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMAA<[90, 100]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }
 
 def TestNVVMRequiresSMFamilyCondOp :
-    TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMf<[100]>]> {
+    TEST_Op<"nvvm_requires_sm_100f", [NVVMRequiresSMFS<[100]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }
 
 def TestNVVMRequiresSMFamilyCondMultiOp :
-    TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMf<[100, 120]>]> {
+    TEST_Op<"nvvm_requires_sm_100f_or_sm_120f", [NVVMRequiresSMFS<[100, 120]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }
 
 def TestNVVMRequiresSMArchOrFamilyCondOp :
-    TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMaOrf<[90], [100]>]> {
+    TEST_Op<"nvvm_requires_sm_90a_or_sm_100f", [NVVMRequiresSMAAOrFS<[90], [100]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }



More information about the Mlir-commits mailing list