[llvm] [mlir] [MLIR][NVVM] Add NVVMRequiresSM op trait (PR #126886)

Srinivasa Ravi via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 15 01:42:44 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/126886

>From 0d542cae636a2e9fa9d873443901f936da8a3bf9 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 20 Jan 2025 16:41:02 +0530
Subject: [PATCH 1/4] [MLIR][NVVM] Add NVVMRequiresSM op trait

This change adds the NVVMRequiresSM op trait to the NVVM dialect, which
allows NVVM ops to be tagged with the minimum SM version they require.
When the target SM can be determined (through NVVMTargetAttr), the
compatibility of each op with that SM is verified up front, without
having to lower the IR any further.
---
 .../GPU/IR/CompilationAttrInterfaces.td       | 14 +++
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  2 +
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        |  6 ++
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  2 +
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 30 ++++---
 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 88 +++++++++++++++++++
 .../include/mlir/Dialect/LLVMIR/NVVMTraits.td | 34 +++++++
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 16 ++++
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |  3 +
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 20 +++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp     | 15 ++++
 .../Dialect/LLVMIR/nvvm-check-targetSM.mlir   | 46 ++++++++++
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |  1 +
 mlir/test/lib/Dialect/Test/TestOps.h          |  1 +
 mlir/test/lib/Dialect/Test/TestOps.td         | 17 ++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 30 ++++++-
 16 files changed, 313 insertions(+), 12 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
 create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
index 6d5fd01499121..0cb85180f0741 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
@@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
   ];
 }
 
+def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> {
+  let description = [{
+    Interface for GPU target attributes that need to verify the target attribute
+    for the given GPU module.
+  }];
+  let cppNamespace = "::mlir::gpu";
+  let methods = [
+    InterfaceMethod<[{
+        Verifies that the target attribute is valid for the given GPU module.
+      }], "::mlir::LogicalResult", "verifyTarget",
+      (ins "::mlir::Operation *":$module)>
+  ];
+}
+
 def GPUTargetAttr :
     ConfinedAttr<AnyAttr, [PromisedAttrInterface<GPUTargetAttrInterface>]> {
   let description = [{
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 2b1ce573effd0..414825e7c634f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
     /// Sets the targets of the module.
     void setTargets(ArrayRef<TargetAttrInterface> targets);
   }];
+  
+  let hasVerifier = 1;
 }
 
 def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 759de745440c2..0efdcbbf9e469 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
 add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
 
+set(LLVM_TARGET_DEFINITIONS NVVMTraits.td)
+mlir_tablegen(NVVMTraits.h.inc -gen-op-interface-decls)
+mlir_tablegen(NVVMTraits.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRNVVMTraitsIncGen)
+add_dependencies(mlir-headers MLIRNVVMTraitsIncGen)
+
 add_mlir_dialect(NVVMOps nvvm)
 add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index a9270c6f52344..7ebbaa6aa15eb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -16,7 +16,9 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index fe15a524ec3b5..4cb431e023441 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
 include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
-  NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic,
+    !listconcat(traits,
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
   let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
   let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
   let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
 def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
 def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
 def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
 def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
 def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
 def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
 
 //===----------------------------------------------------------------------===//
 // Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
 def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
 def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
 def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
 def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
 def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
 def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
 
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
 def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
 def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
 def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
 def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
 
 def NVVM_ReduxOp :
-  NVVM_Op<"redux.sync">,
+  NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$val,
                  ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
+                              [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -2341,8 +2345,8 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
-  Arguments<(ins )> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",  
+                              [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
   let assemblyFormat = "attr-dict";
   let description = [{
     Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2814,7 +2818,8 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
 
-def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
+def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target", 
+  [DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
   let description = [{
     GPU target attribute for controlling compilation of NVIDIA targets. All
     parameters decay into default values if not present.
@@ -2862,6 +2867,9 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
     bool hasFlag(StringRef flag) const;
     bool hasFastMath() const;
     bool hasFtz() const;
+    bool hasCmdOptions() const;
+    std::optional<mlir::NamedAttribute> getCmdOptions() const;
+    LogicalResult verifyTarget(Operation *gpuModule);
   }];
   let extraClassDefinition = [{
     bool $cppClass::hasFlag(StringRef flag) const {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
new file mode 100644
index 0000000000000..6174309c46185
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -0,0 +1,88 @@
+//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StorageUniquerSupport.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+
+namespace NVVM {
+
+struct NVVMCheckSMVersion {
+  int archVersion;
+  bool archAccelerated;
+  std::string archString;
+
+  NVVMCheckSMVersion() {}
+  NVVMCheckSMVersion(StringRef SMVersion) : archString(SMVersion) {
+    parse(SMVersion);
+  }
+  NVVMCheckSMVersion(int archVersion, bool archAccelerated)
+      : archVersion(archVersion), archAccelerated(archAccelerated) {
+    archString = (llvm::Twine("sm_") + llvm::Twine(archVersion) +
+                  (archAccelerated ? "a" : "\0"))
+                     .str();
+  }
+
+  const StringRef getArchString() const { return archString; }
+
+  // Parses the SM version string and sets the archVersion (integer) and
+  // the archAccelerated flag.
+  void parse(StringRef SMVersion) {
+    archAccelerated = (SMVersion.back() == 'a');
+    SMVersion.drop_front(3)
+        .take_while([](char c) { return llvm::isDigit(c); })
+        .getAsInteger(10, archVersion);
+  }
+
+  bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+    // for arch-conditional SMs, they should exactly match to be valid
+    if (archAccelerated || TargetSM.archAccelerated)
+      return (*this) == TargetSM;
+
+    return archVersion <= TargetSM.archVersion;
+  }
+
+  bool operator==(const NVVMCheckSMVersion &Other) const {
+    return archVersion == Other.archVersion &&
+           archAccelerated == Other.archAccelerated;
+  }
+};
+} // namespace NVVM
+} // namespace mlir
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h.inc"
+
+namespace mlir {
+
+namespace OpTrait {
+
+template <int Version, bool ArchAccelerated = false>
+class NVVMRequiresSM {
+public:
+  template <typename ConcreteOp>
+  class Impl : public OpTrait::TraitBase<
+                   ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl>,
+               public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+  public:
+    const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+      return NVVM::NVVMCheckSMVersion(Version, ArchAccelerated);
+    }
+  };
+};
+} // namespace OpTrait
+} // namespace mlir
+#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
new file mode 100644
index 0000000000000..e8ec74eef2cb3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -0,0 +1,34 @@
+//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_TRAITS
+#define NVVM_TRAITS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+// Interface for NVVM Ops with the NVVMRequiresSM parametric trait
+def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
+  let cppNamespace = "::mlir::NVVM";
+  let methods = [
+    InterfaceMethod<
+      "Get the SM version required by the op from the trait", 
+      "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
+    >
+  ];
+}
+
+class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+  ParamNativeOpTrait<"NVVMRequiresSM",
+                    !cast<string>(Version) # "," # ArchAccelerated>;
+
+#endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d06f10d3137a1..d4c3e65c707cd 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1793,6 +1793,22 @@ void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
   targetsAttr = ArrayAttr::get(getContext(), targetsVector);
 }
 
+LogicalResult GPUModuleOp::verify() {
+  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
+
+  if (!targets)
+    return success();
+
+  for (auto target : targets) {
+    if (auto verifyTargetAttr =
+            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
+      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
+        return failure();
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPUBinaryOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c9a3b97294562..7030f0f5fd6e2 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
 add_mlir_dialect_library(MLIRNVVMDialect
   IR/NVVMDialect.cpp
   IR/BasicPtxBuilderInterface.cpp
+  IR/NVVMTraits.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
@@ -51,6 +52,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
   MLIRNVVMOpsIncGen
   MLIRNVVMConversionsIncGen
   MLIRBasicPtxBuilderInterfaceIncGen
+  MLIRNVVMTraitsIncGen
   intrinsics_gen
 
   LINK_COMPONENTS
@@ -60,6 +62,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRGPUDialect
   MLIRSideEffectInterfaces
   MLIRInferIntRangeInterface
   )
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..b2da955a0909c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -18,6 +18,7 @@
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -1439,6 +1440,25 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
+  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
+  if (!gpuModuleOp)
+    return emitError(gpuModule->getLoc(),
+                     "NVVM target attribute must be attached to a GPU module");
+  gpuModuleOp->walk([&](Operation *op) {
+    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
+      auto requirement = reqOp.getRequiredMinSMVersion();
+      if (!requirement.isCompatible(NVVMCheckSMVersion(getChip()))) {
+        op->emitOpError() << "is not supported on " << getChip();
+        return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
new file mode 100644
index 0000000000000..c774bac2400f3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
@@ -0,0 +1,15 @@
+//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
new file mode 100644
index 0000000000000..bf5c349a9aa7b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Just check these don't emit errors.
+gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+  test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+  // expected-error @below {{is not supported on sm_70}}
+  test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+  // expected-error @below {{is not supported on sm_75}}
+  test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+  // expected-error @below {{is not supported on sm_90}}
+  test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+  // expected-error @below {{is not supported on sm_80}}
+  test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 618b13da9899f..b1ffbcc7df9a2 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRLinalgDialect
   MLIRLinalgTransforms
   MLIRLLVMDialect
+  MLIRNVVMDialect
   MLIRPass
   MLIRPolynomialDialect
   MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..1adb86ecf7b2a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Traits.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..4f922e62e2b1a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -13,6 +13,7 @@ include "TestDialect.td"
 include "TestInterfaces.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -2698,6 +2699,22 @@ def TestLinalgFillOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test NVVM RequiresSM trait.
+//===----------------------------------------------------------------------===//
+
+def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
+                                    [NVVMRequiresSM<80>]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+                                          [NVVMRequiresSM<90, "true">]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e07891f004850..cc5f18acf125b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6527,15 +6527,18 @@ cc_library(
     name = "NVVMDialect",
     srcs = [
         "lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp",
+        "lib/Dialect/LLVMIR/IR/NVVMTraits.cpp",
         "lib/Dialect/LLVMIR/IR/NVVMDialect.cpp",
     ],
     hdrs = [
         "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h",
+        "include/mlir/Dialect/LLVMIR/NVVMTraits.h",
         "include/mlir/Dialect/LLVMIR/NVVMDialect.h",
-    ],
+],
     includes = ["include"],
     deps = [
         ":BasicPtxBuilderIntGen",
+        ":NVVMTraitsIntGen",
         ":BytecodeOpInterface",
         ":ConvertToLLVMInterface",
         ":DialectUtils",
@@ -6608,12 +6611,20 @@ td_library(
     ],
 )
 
+td_library(
+    name = "NVVMTraitsIntTdFiles",
+    srcs = ["include/mlir/Dialect/LLVMIR/NVVMTraits.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"]
+)
+
 td_library(
     name = "NVVMOpsTdFiles",
     srcs = ["include/mlir/Dialect/LLVMIR/NVVMOps.td"],
     includes = ["include"],
     deps = [
         ":BasicPtxBuilderIntTdFiles",
+        ":NVVMTraitsIntTdFiles",
         ":GPUOpsTdFiles",
         ":LLVMOpsTdFiles",
         ":OpBaseTdFiles",
@@ -6646,6 +6657,23 @@ gentbl_cc_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "NVVMTraitsIntGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/LLVMIR/NVVMTraits.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/LLVMIR/NVVMTraits.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/NVVMTraits.td",
+    deps = [":NVVMTraitsIntTdFiles"],
+)
+
 gentbl_cc_library(
     name = "NVVMOpsIncGen",
     tbl_outs = [

>From 7305d3a3cb6e2e33e732035ff312a642822de1cf Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Mar 2025 16:01:22 +0530
Subject: [PATCH 2/4] [MLIR][NVVM] Add NVVMRequiresSM op trait

This change adds the NVVMRequiresSM op trait to the NVVM dialect, which
allows NVVM ops to be tagged with the minimum SM version they require.
When the target SM can be determined (through NVVMTargetAttr), the
compatibility of each op with that SM is verified up front, without
having to lower the IR any further.
---
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  2 +-
 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 33 +++++++------------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  8 +++--
 3 files changed, 18 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 7ebbaa6aa15eb..ec18b38153103 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -15,8 +15,8 @@
 #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMTraits.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 6174309c46185..51b868cfe5a7e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -24,41 +24,32 @@ namespace NVVM {
 struct NVVMCheckSMVersion {
   int archVersion;
   bool archAccelerated;
-  std::string archString;
 
   NVVMCheckSMVersion() {}
-  NVVMCheckSMVersion(StringRef SMVersion) : archString(SMVersion) {
-    parse(SMVersion);
-  }
+  NVVMCheckSMVersion(StringRef smVersion) { parse(smVersion); }
   NVVMCheckSMVersion(int archVersion, bool archAccelerated)
-      : archVersion(archVersion), archAccelerated(archAccelerated) {
-    archString = (llvm::Twine("sm_") + llvm::Twine(archVersion) +
-                  (archAccelerated ? "a" : "\0"))
-                     .str();
-  }
-
-  const StringRef getArchString() const { return archString; }
+      : archVersion(archVersion), archAccelerated(archAccelerated) {}
 
   // Parses the SM version string and sets the archVersion (integer) and
   // the archAccelerated flag.
-  void parse(StringRef SMVersion) {
-    archAccelerated = (SMVersion.back() == 'a');
-    SMVersion.drop_front(3)
+  void parse(StringRef smVersion) {
+    archAccelerated = (smVersion.back() == 'a');
+    smVersion.drop_front(3)
         .take_while([](char c) { return llvm::isDigit(c); })
         .getAsInteger(10, archVersion);
   }
 
-  bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+  bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
     // for arch-conditional SMs, they should exactly match to be valid
-    if (archAccelerated || TargetSM.archAccelerated)
-      return (*this) == TargetSM;
+    if (archAccelerated || targetSM.archAccelerated)
+      return (*this) == targetSM;
 
-    return archVersion <= TargetSM.archVersion;
+    return archVersion <= targetSM.archVersion;
   }
 
-  bool operator==(const NVVMCheckSMVersion &Other) const {
-    return archVersion == Other.archVersion &&
-           archAccelerated == Other.archAccelerated;
+  bool operator==(const NVVMCheckSMVersion &other) const {
+    return archVersion == other.archVersion &&
+           archAccelerated == other.archAccelerated;
   }
 };
 } // namespace NVVM
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 82b263debc587..74da5e3762ecd 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1564,17 +1564,19 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
   if (!gpuModuleOp)
     return emitError(gpuModule->getLoc(),
                      "NVVM target attribute must be attached to a GPU module");
+
+  NVVMCheckSMVersion targetSMVersion(getChip());
   gpuModuleOp->walk([&](Operation *op) {
     if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
-      auto requirement = reqOp.getRequiredMinSMVersion();
-      if (!requirement.isCompatible(NVVMCheckSMVersion(getChip()))) {
+      NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
+      if (!requirement.isCompatible(targetSMVersion)) {
         op->emitOpError() << "is not supported on " << getChip();
         return WalkResult::interrupt();
       }
     }
     return WalkResult::advance();
   });
-  
+
   return success();
 }
 

>From c8376db8a3b2ffec99ab4fbec2da572d92b423ac Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Mar 2025 20:41:37 +0530
Subject: [PATCH 3/4] Add verifyTarget parameter to NVVMTargetAttr

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td       | 10 ++++++----
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp        |  5 ++++-
 mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir | 10 ++++++++++
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 98fbdc1f0f9c0..9bea0e7417a63 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3167,10 +3167,11 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
     StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
     StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
     OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
-    OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
+    OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link,
+    DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget
   );
   let assemblyFormat = [{
-    (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
+    (`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget)^ `>`)?
   }];
   let builders = [
     AttrBuilder<(ins CArg<"int", "2">:$optLevel,
@@ -3178,8 +3179,9 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
                      CArg<"StringRef", "\"sm_50\"">:$chip,
                      CArg<"StringRef", "\"+ptx60\"">:$features,
                      CArg<"DictionaryAttr", "nullptr">:$targetFlags,
-                     CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
-      return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
+                     CArg<"ArrayAttr", "nullptr">:$linkFiles,
+                     CArg<"bool", "true">:$verifyTarget), [{
+      return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget);
     }]>
   ];
   let skipDefaultBuilders = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 74da5e3762ecd..1316f396c9436 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1537,7 +1537,7 @@ LogicalResult
 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        int optLevel, StringRef triple, StringRef chip,
                        StringRef features, DictionaryAttr flags,
-                       ArrayAttr files) {
+                       ArrayAttr files, bool verifyTarget) {
   if (optLevel < 0 || optLevel > 3) {
     emitError() << "The optimization level must be a number between 0 and 3.";
     return failure();
@@ -1560,6 +1560,9 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
+  if (!getVerifyTarget())
+    return success();
+
   auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
   if (!gpuModuleOp)
     return emitError(gpuModule->getLoc(),
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index bf5c349a9aa7b..97adeb000c55f 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -17,6 +17,16 @@ gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
   test.nvvm_requires_sm_90a
 }
 
+gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+  test.nvvm_requires_sm_90a
+}
+
+gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget = false>] {
+  test.nvvm_requires_sm_80
+}
+
+
+
 // -----
 
 gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {

>From 0a33861c81796aa0a634716323acf511ead62aae Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 14:12:16 +0530
Subject: [PATCH 4/4] address comments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 52 +++++++++----------
 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 36 ++++++++-----
 .../include/mlir/Dialect/LLVMIR/NVVMTraits.td |  8 ++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  6 ++-
 .../Dialect/LLVMIR/nvvm-check-targetSM.mlir   | 33 ++++++++++--
 mlir/test/lib/Dialect/Test/TestOps.td         | 10 +++-
 6 files changed, 96 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9bea0e7417a63..8eac9dc7ac8a1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -214,15 +214,15 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
 def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
-def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z", [NVVMRequiresSM<90>]>;
 
 //===----------------------------------------------------------------------===//
 // CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
-def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank", [NVVMRequiresSM<90>]>;
 
 //===----------------------------------------------------------------------===//
@@ -323,7 +323,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
 }
 
 /// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
   Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
@@ -545,7 +545,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
   let assemblyFormat = "attr-dict";
 }
 
-def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
+def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>]> {
   let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
 
   let summary = "Cluster Barrier Relaxed Arrive Op";
@@ -571,7 +571,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
   let assemblyFormat = "attr-dict";
 }
 
-def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
+def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> {
   let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
 
   let summary = "Cluster Barrier Wait Op";
@@ -776,7 +776,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
 def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
-  NVVM_Op<"shfl.sync">,
+  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins I32:$thread_mask,
                  LLVM_Type:$val,
@@ -1880,7 +1880,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   }];
 }
 
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>]>,
   Arguments<(ins 
     ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 
     OptionalAttr<UnitAttr>:$read)> {
@@ -1910,7 +1910,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
-  AttrSizedOperandSegments]>,
+  AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
   Arguments<(ins  LLVM_PointerShared:$dstMem,
                   LLVM_AnyPointer:$tmaDescriptor,
                   Variadic<I32>:$coordinates,
@@ -2347,8 +2347,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
-                              [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSM90a]> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -2362,8 +2361,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",  
-                              [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSM90a]> {
   let assemblyFormat = "attr-dict";
   let description = [{
     Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2375,7 +2373,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSM90a]> {
   let arguments = (ins I64Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
@@ -2571,7 +2569,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp
 
 def NVVM_MapaOp: NVVM_Op<"mapa",
     [TypesMatchWith<"`res` and `a` should have the same type",
-                    "a", "res", "$_self">]> {
+                    "a", "res", "$_self">, NVVMRequiresSM<90>]> {
   let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
   let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
 
@@ -2662,7 +2660,7 @@ def Tcgen05WaitKindAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
+def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 alloc operation";
   let description = [{
     The `tcgen05.alloc` Op allocates tensor core memory for
@@ -2692,7 +2690,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
   }];
 }
 
-def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
+def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 dealloc operation";
   let description = [{
     The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -2720,7 +2718,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
   }];
 }
 
-def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
+def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 Op to relinquish the right to allocate";
   let description = [{
     The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -2743,7 +2741,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
   }];
 }
 
-def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
+def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 fence operations";
   let description = [{
     The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -2765,7 +2763,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
   }];
 }
 
-def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
+def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 wait operations";
   let description = [{
     The `tcgen05.wait<load>` causes the executing thread to block until
@@ -2787,7 +2785,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
   }];
 }
 
-def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
+def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 commit operations";
   let description = [{
     The `tcgen05.commit` makes the mbarrier object, specified by
@@ -2825,7 +2823,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
   }];
 }
 
-def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 shift operation";
   let description = [{
     The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -2891,7 +2889,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "Tcgen05 copy operation";
   let description = [{
     Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -2961,7 +2959,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
 // NVVM tcgen05.ld Op
 //===----------------------------------------------------------------------===//
 
-def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
+def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "tensor memory load instructions";
   let arguments = (ins
     // Attributes
@@ -3051,7 +3049,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
 // NVVM tcgen05.st Op
 //===----------------------------------------------------------------------===//
 
-def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
+def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSM<100, "true", "false">]> {
   let summary = "tensor memory store instructions";
   let arguments = (ins
     // Attributes
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 51b868cfe5a7e..04b408846e820 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -21,14 +21,21 @@ namespace mlir {
 
 namespace NVVM {
 
+// Structure to store and check compatibility of SM versions.
 struct NVVMCheckSMVersion {
   int archVersion;
   bool archAccelerated;
+  bool exactMatch;
 
-  NVVMCheckSMVersion() {}
-  NVVMCheckSMVersion(StringRef smVersion) { parse(smVersion); }
-  NVVMCheckSMVersion(int archVersion, bool archAccelerated)
-      : archVersion(archVersion), archAccelerated(archAccelerated) {}
+  NVVMCheckSMVersion()
+      : archVersion(0), archAccelerated(false), exactMatch(false) {}
+  NVVMCheckSMVersion(StringRef smVersion, bool exactMatch = false)
+      : exactMatch(exactMatch) {
+    parse(smVersion);
+  }
+  NVVMCheckSMVersion(int archVersion, bool archAccelerated, bool exactMatch)
+      : archVersion(archVersion), archAccelerated(archAccelerated),
+        exactMatch(exactMatch) {}
 
   // Parses the SM version string and sets the archVersion (integer) and
   // the archAccelerated flag.
@@ -40,11 +47,12 @@ struct NVVMCheckSMVersion {
   }
 
   bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
-    // for arch-conditional SMs, they should exactly match to be valid
-    if (archAccelerated || targetSM.archAccelerated)
+    if (exactMatch)
       return (*this) == targetSM;
-
-    return archVersion <= targetSM.archVersion;
+
+    // An arch-accelerated requirement also needs an arch-accelerated target.
+    return archVersion <= targetSM.archVersion &&
+           (!archAccelerated || targetSM.archAccelerated);
   }
 
   bool operator==(const NVVMCheckSMVersion &other) const {
@@ -61,16 +69,18 @@ namespace mlir {
 
 namespace OpTrait {
 
-template <int Version, bool ArchAccelerated = false>
+template <int MinVersion, bool ArchAccelerated = false, bool ExactMatch = false>
 class NVVMRequiresSM {
 public:
   template <typename ConcreteOp>
-  class Impl : public OpTrait::TraitBase<
-                   ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl>,
-               public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+  class Impl
+      : public OpTrait::TraitBase<
+            ConcreteOp,
+            NVVMRequiresSM<MinVersion, ArchAccelerated, ExactMatch>::Impl>,
+        public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
   public:
     const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
-      return NVVM::NVVMCheckSMVersion(Version, ArchAccelerated);
+      return NVVM::NVVMCheckSMVersion(MinVersion, ArchAccelerated, ExactMatch);
     }
   };
 };
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
index e8ec74eef2cb3..5e013b13dc26b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -27,8 +27,12 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
   ];
 }
 
-class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+class NVVMRequiresSM<int minVersion, string isArchAccelerated = "false",
+                    string exactMatch = "false"> :
   ParamNativeOpTrait<"NVVMRequiresSM",
-                    !cast<string>(Version) # "," # ArchAccelerated>;
+                    !cast<string>(minVersion) # "," # isArchAccelerated # ","
+                      # exactMatch>;
+                      
+def NVVMRequiresSM90a : NVVMRequiresSM<90, "true", "true">;
 
 #endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1316f396c9436..a16584b05c02e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1567,8 +1567,12 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
   if (!gpuModuleOp)
     return emitError(gpuModule->getLoc(),
                      "NVVM target attribute must be attached to a GPU module");
-
+  
   NVVMCheckSMVersion targetSMVersion(getChip());
+  if (targetSMVersion.archVersion < 20)
+    return emitError(gpuModule->getLoc(),
+                     "Minimum NVVM target SM version is sm_20");
+
   gpuModuleOp->walk([&](Operation *op) {
     if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
       NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index 97adeb000c55f..13a21d1e2156f 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -13,10 +13,19 @@ gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
   test.nvvm_requires_sm_80
 }
 
-gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+gpu.module @check_valid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90a">] {
   test.nvvm_requires_sm_90a
 }
 
+gpu.module @check_valid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_90a">] {
+  test.nvvm_requires_sm_atleast_90_aa
+}
+
+gpu.module @check_valid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_100a">] {
+  test.nvvm_requires_sm_atleast_90_aa
+}
+
+
 gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
   test.nvvm_requires_sm_90a
 }
@@ -25,7 +34,9 @@ gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget =
   test.nvvm_requires_sm_80
 }
 
-
+gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+  test.nvvm_requires_sm_atleast_90_aa
+}
 
 // -----
 
@@ -43,14 +54,28 @@ gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
 
 // -----
 
-gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+gpu.module @check_invalid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90">] {
   // expected-error @below {{is not supported on sm_90}}
   test.nvvm_requires_sm_90a
 }
 
 // -----
 
-gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+gpu.module @check_invalid_SM_arch_acc_exact_2 [#nvvm.target<chip = "sm_80">] {
   // expected-error @below {{is not supported on sm_80}}
   test.nvvm_requires_sm_90a
 }
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_80">] {
+  // expected-error @below {{is not supported on sm_80}}
+  test.nvvm_requires_sm_atleast_90_aa
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_90">] {
+  // expected-error @below {{is not supported on sm_90}}
+  test.nvvm_requires_sm_atleast_90_aa
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2ef9b21a16258..cda84dfc89ef3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2739,8 +2739,14 @@ def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
   let assemblyFormat = "attr-dict";
 }
 
-def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
-                                          [NVVMRequiresSM<90, "true">]> {
+def TestNVVMRequiresAtleastSMArchCondOp : 
+    TEST_Op<"nvvm_requires_sm_atleast_90_aa", [NVVMRequiresSM<90, "true">]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresExactSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+                                          [NVVMRequiresSM90a]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
 }



More information about the llvm-commits mailing list