[Mlir-commits] [llvm] [mlir] [MLIR][NVVM] Add NVVMRequiresSM op trait (PR #126886)
Srinivasa Ravi
llvmlistbot at llvm.org
Wed May 7 23:38:15 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/126886
>From 749bcff7859135ca6f8a2643ec1c9eddf698f504 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 8 May 2025 10:50:33 +0530
Subject: [PATCH 1/8] [MLIR][NVVM] Add NVVMRequiresSM op trait
This change adds the NVVMRequiresSM op trait to the NVVM dialect, which
allows NVVM ops to be tagged with a minimum required SM version. When the
target SM can be determined (through NVVMTargetAttr), each op's
compatibility with that SM is verified up front, without having to lower
the IR any further.
---
.../GPU/IR/CompilationAttrInterfaces.td | 14 +++
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 2 +
.../mlir/Dialect/LLVMIR/CMakeLists.txt | 6 ++
.../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 2 +
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 28 +++---
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 88 +++++++++++++++++++
.../include/mlir/Dialect/LLVMIR/NVVMTraits.td | 34 +++++++
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 16 ++++
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 3 +
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 20 +++++
mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp | 15 ++++
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 46 ++++++++++
mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/Test/TestOps.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 17 ++++
.../llvm-project-overlay/mlir/BUILD.bazel | 30 ++++++-
16 files changed, 311 insertions(+), 12 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
create mode 100644 mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
index 6d5fd01499121..0cb85180f0741 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
@@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
];
}
+def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> {
+ let description = [{
+ Interface for GPU target attributes that need to verify the target attribute
+ for the given GPU module.
+ }];
+ let cppNamespace = "::mlir::gpu";
+ let methods = [
+ InterfaceMethod<[{
+ Verifies that the target attribute is valid for the given GPU module.
+ }], "::mlir::LogicalResult", "verifyTarget",
+ (ins "::mlir::Operation *":$module)>
+ ];
+}
+
def GPUTargetAttr :
ConfinedAttr<AnyAttr, [PromisedAttrInterface<GPUTargetAttrInterface>]> {
let description = [{
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 68095b7bf5c59..8d83d02e27c33 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
/// Sets the targets of the module.
void setTargets(ArrayRef<TargetAttrInterface> targets);
}];
+
+ let hasVerifier = 1;
}
def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 759de745440c2..0efdcbbf9e469 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS NVVMTraits.td)
+mlir_tablegen(NVVMTraits.h.inc -gen-op-interface-decls)
+mlir_tablegen(NVVMTraits.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRNVVMTraitsIncGen)
+add_dependencies(mlir-headers MLIRNVVMTraitsIncGen)
+
add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index f1eae15d6bf18..fe01f29b8dfd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -16,7 +16,9 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..75e4a3c68762c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -138,8 +139,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}
-class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
- NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+ NVVM_SpecialRegisterOp<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -169,14 +172,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -202,7 +205,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
//===----------------------------------------------------------------------===//
// CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -212,7 +215,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -273,7 +276,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
def NVVM_ReduxOp :
- NVVM_Op<"redux.sync">,
+ NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
@@ -2581,7 +2584,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
+ [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2595,8 +2599,8 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
- Arguments<(ins )> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
+ [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -3448,7 +3452,8 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
// NVVM target attribute.
//===----------------------------------------------------------------------===//
-def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
+def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target",
+ [DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
let description = [{
GPU target attribute for controlling compilation of NVIDIA targets. All
parameters decay into default values if not present.
@@ -3498,6 +3503,7 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
bool hasFtz() const;
bool hasCmdOptions() const;
std::optional<mlir::NamedAttribute> getCmdOptions() const;
+ LogicalResult verifyTarget(Operation *gpuModule);
}];
let extraClassDefinition = [{
bool $cppClass::hasFlag(StringRef flag) const {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
new file mode 100644
index 0000000000000..6174309c46185
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -0,0 +1,88 @@
+//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StorageUniquerSupport.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+
+namespace NVVM {
+
+// Describes an SM architecture as an integer version plus an
+// "arch-accelerated" flag, together with its textual form (archString).
+// Used to compare the SM an op requires against the SM of a target.
+struct NVVMCheckSMVersion {
+ int archVersion;
+ bool archAccelerated;
+ std::string archString;
+
+ // NOTE(review): this constructor leaves archVersion and archAccelerated
+ // uninitialized; they have no default member initializers.
+ NVVMCheckSMVersion() {}
+ // Builds the version from a chip string: stores the string and lets
+ // parse() fill in archVersion and archAccelerated.
+ NVVMCheckSMVersion(StringRef SMVersion) : archString(SMVersion) {
+ parse(SMVersion);
+ }
+ // Builds the version from its numeric parts and derives archString as
+ // "sm_<archVersion>", followed by "a" when archAccelerated is set.
+ NVVMCheckSMVersion(int archVersion, bool archAccelerated)
+ : archVersion(archVersion), archAccelerated(archAccelerated) {
+ archString = (llvm::Twine("sm_") + llvm::Twine(archVersion) +
+ (archAccelerated ? "a" : "\0"))
+ .str();
+ }
+
+ // Returns a view of archString; the view refers to this object's storage.
+ const StringRef getArchString() const { return archString; }
+
+ // Parses the SM version string and sets the archVersion (integer) and
+ // the archAccelerated flag.
+ // archAccelerated is set when the last character is 'a'. archVersion is
+ // read in base 10 from the run of digits that follows the first three
+ // characters (presumably the "sm_" prefix; the prefix itself is not checked).
+ // NOTE(review): back() is called without checking that SMVersion is
+ // non-empty, and the value returned by getAsInteger is discarded.
+ void parse(StringRef SMVersion) {
+ archAccelerated = (SMVersion.back() == 'a');
+ SMVersion.drop_front(3)
+ .take_while([](char c) { return llvm::isDigit(c); })
+ .getAsInteger(10, archVersion);
+ }
+
+ // Returns true when this (required) version is satisfied by TargetSM.
+ bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+ // for arch-conditional SMs, they should exactly match to be valid
+ if (archAccelerated || TargetSM.archAccelerated)
+ return (*this) == TargetSM;
+
+ // Otherwise any target whose version is at least this version is accepted.
+ return archVersion <= TargetSM.archVersion;
+ }
+
+ // Equality compares archVersion and archAccelerated only; archString is
+ // not part of the comparison.
+ bool operator==(const NVVMCheckSMVersion &Other) const {
+ return archVersion == Other.archVersion &&
+ archAccelerated == Other.archAccelerated;
+ }
+};
+} // namespace NVVM
+} // namespace mlir
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h.inc"
+
+namespace mlir {
+
+namespace OpTrait {
+
+// Parametric op trait that carries the SM requirement of an op as template
+// arguments: `Version` is the SM version number and `ArchAccelerated`
+// (default false) is forwarded as the arch-accelerated flag.
+template <int Version, bool ArchAccelerated = false>
+class NVVMRequiresSM {
+public:
+ // Per-op implementation of the trait. It also derives from
+ // RequiresSMInterface::Trait, which is what provides the op with the
+ // interface whose method is defined below.
+ template <typename ConcreteOp>
+ class Impl : public OpTrait::TraitBase<
+ ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ public:
+ // Returns the requirement encoded in the template arguments as an
+ // NVVMCheckSMVersion value.
+ const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
+ return NVVM::NVVMCheckSMVersion(Version, ArchAccelerated);
+ }
+ };
+};
+} // namespace OpTrait
+} // namespace mlir
+#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
new file mode 100644
index 0000000000000..e8ec74eef2cb3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -0,0 +1,34 @@
+//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_TRAITS
+#define NVVM_TRAITS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+// Interface for NVVM Ops with the NVVMRequiresSM parametric trait
+def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
+ let cppNamespace = "::mlir::NVVM";
+ let methods = [
+ InterfaceMethod<
+ "Get the SM version required by the op from the trait",
+ "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
+ >
+ ];
+}
+
+class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+ ParamNativeOpTrait<"NVVMRequiresSM",
+ !cast<string>(Version) # "," # ArchAccelerated>;
+
+#endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 84e3071946f59..39f626b558294 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1790,6 +1790,22 @@ void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}
+/// Verifies the module against each of its target attributes that implements
+/// TargetAttrVerifyInterface. Returns failure as soon as one of them rejects
+/// the module; a module without a `targets` array attribute is always valid.
+LogicalResult GPUModuleOp::verify() {
+  Operation *moduleOp = getOperation();
+  auto targetAttrs = moduleOp->getAttrOfType<ArrayAttr>("targets");
+
+  // Nothing to check when no targets are attached.
+  if (!targetAttrs)
+    return success();
+
+  for (auto targetAttr : targetAttrs) {
+    // Attributes that do not opt into target verification are skipped.
+    auto verifier = llvm::dyn_cast<TargetAttrVerifyInterface>(targetAttr);
+    if (!verifier)
+      continue;
+    if (failed(verifier.verifyTarget(moduleOp)))
+      return failure();
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// GPUBinaryOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c9a3b97294562..7030f0f5fd6e2 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
add_mlir_dialect_library(MLIRNVVMDialect
IR/NVVMDialect.cpp
IR/BasicPtxBuilderInterface.cpp
+ IR/NVVMTraits.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
@@ -51,6 +52,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
MLIRNVVMOpsIncGen
MLIRNVVMConversionsIncGen
MLIRBasicPtxBuilderInterfaceIncGen
+ MLIRNVVMTraitsIncGen
intrinsics_gen
LINK_COMPONENTS
@@ -60,6 +62,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
+ MLIRGPUDialect
MLIRSideEffectInterfaces
MLIRInferIntRangeInterface
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..75733d3d7afe6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -1704,6 +1705,25 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+/// Verifies that every op nested in `gpuModule` that implements
+/// NVVM::RequiresSMInterface is compatible with the chip of this target.
+///
+/// Emits an error and returns failure when `gpuModule` is not a
+/// gpu::GPUModuleOp, or at the first op whose required SM version is not
+/// compatible with the chip. Returns success otherwise.
+LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
+  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
+  if (!gpuModuleOp)
+    return emitError(gpuModule->getLoc(),
+                     "NVVM target attribute must be attached to a GPU module");
+
+  // Keep the walk result: the walk is interrupted exactly when an
+  // incompatible op was found and diagnosed.
+  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
+    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
+      auto requirement = reqOp.getRequiredMinSMVersion();
+      if (!requirement.isCompatible(NVVMCheckSMVersion(getChip()))) {
+        op->emitOpError() << "is not supported on " << getChip();
+        return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  // Propagate the failure instead of unconditionally reporting success
+  // after an error has already been emitted.
+  return failure(walkResult.wasInterrupted());
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
new file mode 100644
index 0000000000000..c774bac2400f3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMTraits.cpp
@@ -0,0 +1,15 @@
+//===--- NVVMTraits.cpp - NVVM Traits ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
+
+#include "mlir/Dialect/LLVMIR/NVVMTraits.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
new file mode 100644
index 0000000000000..bf5c349a9aa7b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Just check these don't emit errors.
+gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+ // expected-error @below {{is not supported on sm_70}}
+ test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+ // expected-error @below {{is not supported on sm_75}}
+ test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+ // expected-error @below {{is not supported on sm_90}}
+ test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+ // expected-error @below {{is not supported on sm_80}}
+ test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 6e608e4772391..f179b05ca08ad 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -86,6 +86,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRLinalgTransforms
MLIRPtrDialect
MLIRLLVMDialect
+ MLIRNVVMDialect
MLIRPass
MLIRPolynomialDialect
MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..1adb86ecf7b2a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -16,6 +16,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Traits.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3e461999e2730..bdf53a2eae047 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -13,6 +13,7 @@ include "TestDialect.td"
include "TestInterfaces.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -2806,6 +2807,22 @@ def TestLinalgFillOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Test NVVM RequiresSM trait.
+//===----------------------------------------------------------------------===//
+
+def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
+ [NVVMRequiresSM<80>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+ [NVVMRequiresSM<90, "true">]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued String Attributes
//===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index da7b783e98ba8..8c21206a83d08 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5771,15 +5771,18 @@ cc_library(
name = "NVVMDialect",
srcs = [
"lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp",
+ "lib/Dialect/LLVMIR/IR/NVVMTraits.cpp",
"lib/Dialect/LLVMIR/IR/NVVMDialect.cpp",
],
hdrs = [
"include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h",
+ "include/mlir/Dialect/LLVMIR/NVVMTraits.h",
"include/mlir/Dialect/LLVMIR/NVVMDialect.h",
- ],
+],
includes = ["include"],
deps = [
":BasicPtxBuilderIntGen",
+ ":NVVMTraitsIntGen",
":BytecodeOpInterface",
":ConvertToLLVMInterface",
":DialectUtils",
@@ -5852,12 +5855,20 @@ td_library(
],
)
+td_library(
+ name = "NVVMTraitsIntTdFiles",
+ srcs = ["include/mlir/Dialect/LLVMIR/NVVMTraits.td"],
+ includes = ["include"],
+ deps = [":OpBaseTdFiles"]
+)
+
td_library(
name = "NVVMOpsTdFiles",
srcs = ["include/mlir/Dialect/LLVMIR/NVVMOps.td"],
includes = ["include"],
deps = [
":BasicPtxBuilderIntTdFiles",
+ ":NVVMTraitsIntTdFiles",
":GPUOpsTdFiles",
":LLVMOpsTdFiles",
":OpBaseTdFiles",
@@ -5884,6 +5895,23 @@ gentbl_cc_library(
],
)
+gentbl_cc_library(
+ name = "NVVMTraitsIntGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/LLVMIR/NVVMTraits.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/LLVMIR/NVVMTraits.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/LLVMIR/NVVMTraits.td",
+ deps = [":NVVMTraitsIntTdFiles"],
+)
+
gentbl_cc_library(
name = "NVVMOpsIncGen",
tbl_outs = {
>From 7b551160e9db56c72b3704129052ef24f5f4f7b7 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Mar 2025 16:01:22 +0530
Subject: [PATCH 2/8] [MLIR][NVVM] Add NVVMRequiresSM op trait
This change adds the NVVMRequiresSM op trait to the NVVM dialect, which
allows NVVM ops to be tagged with a minimum required SM version. When the
target SM can be determined (through NVVMTargetAttr), each op's
compatibility with that SM is verified up front, without having to lower
the IR any further.
---
.../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 2 +-
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 33 +++++++------------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 +++--
3 files changed, 18 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index fe01f29b8dfd0..add03efa3bdc9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -15,8 +15,8 @@
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 6174309c46185..51b868cfe5a7e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -24,41 +24,32 @@ namespace NVVM {
struct NVVMCheckSMVersion {
int archVersion;
bool archAccelerated;
- std::string archString;
NVVMCheckSMVersion() {}
- NVVMCheckSMVersion(StringRef SMVersion) : archString(SMVersion) {
- parse(SMVersion);
- }
+ NVVMCheckSMVersion(StringRef smVersion) { parse(smVersion); }
NVVMCheckSMVersion(int archVersion, bool archAccelerated)
- : archVersion(archVersion), archAccelerated(archAccelerated) {
- archString = (llvm::Twine("sm_") + llvm::Twine(archVersion) +
- (archAccelerated ? "a" : "\0"))
- .str();
- }
-
- const StringRef getArchString() const { return archString; }
+ : archVersion(archVersion), archAccelerated(archAccelerated) {}
// Parses the SM version string and sets the archVersion (integer) and
// the archAccelerated flag.
- void parse(StringRef SMVersion) {
- archAccelerated = (SMVersion.back() == 'a');
- SMVersion.drop_front(3)
+ void parse(StringRef smVersion) {
+ archAccelerated = (smVersion.back() == 'a');
+ smVersion.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, archVersion);
}
- bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+ bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
// for arch-conditional SMs, they should exactly match to be valid
- if (archAccelerated || TargetSM.archAccelerated)
- return (*this) == TargetSM;
+ if (archAccelerated || targetSM.archAccelerated)
+ return (*this) == targetSM;
- return archVersion <= TargetSM.archVersion;
+ return archVersion <= targetSM.archVersion;
}
- bool operator==(const NVVMCheckSMVersion &Other) const {
- return archVersion == Other.archVersion &&
- archAccelerated == Other.archAccelerated;
+ bool operator==(const NVVMCheckSMVersion &other) const {
+ return archVersion == other.archVersion &&
+ archAccelerated == other.archAccelerated;
}
};
} // namespace NVVM
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 75733d3d7afe6..00446327a2244 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1710,17 +1710,19 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
if (!gpuModuleOp)
return emitError(gpuModule->getLoc(),
"NVVM target attribute must be attached to a GPU module");
+
+ NVVMCheckSMVersion targetSMVersion(getChip());
gpuModuleOp->walk([&](Operation *op) {
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
- auto requirement = reqOp.getRequiredMinSMVersion();
- if (!requirement.isCompatible(NVVMCheckSMVersion(getChip()))) {
+ NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
+ if (!requirement.isCompatible(targetSMVersion)) {
op->emitOpError() << "is not supported on " << getChip();
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
-
+
return success();
}
>From 302e92ee75ae7444b82d7f40c8e0ba8d4d311df5 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Mar 2025 20:41:37 +0530
Subject: [PATCH 3/8] Add verifyTarget parameter to NVVMTargetAttr
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 10 ++++++----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 5 ++++-
mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir | 10 ++++++++++
3 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 75e4a3c68762c..d0060b24b7484 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3480,10 +3480,11 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
- OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
+ OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link,
+ DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget
);
let assemblyFormat = [{
- (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
+ (`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget)^ `>`)?
}];
let builders = [
AttrBuilder<(ins CArg<"int", "2">:$optLevel,
@@ -3491,8 +3492,9 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
CArg<"StringRef", "\"sm_50\"">:$chip,
CArg<"StringRef", "\"+ptx60\"">:$features,
CArg<"DictionaryAttr", "nullptr">:$targetFlags,
- CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
- return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
+ CArg<"ArrayAttr", "nullptr">:$linkFiles,
+ CArg<"bool", "true">:$verifyTarget), [{
+ return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget);
}]>
];
let skipDefaultBuilders = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 00446327a2244..9b222eca4d5a1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1683,7 +1683,7 @@ LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int optLevel, StringRef triple, StringRef chip,
StringRef features, DictionaryAttr flags,
- ArrayAttr files) {
+ ArrayAttr files, bool verifyTarget) {
if (optLevel < 0 || optLevel > 3) {
emitError() << "The optimization level must be a number between 0 and 3.";
return failure();
@@ -1706,6 +1706,9 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
+ if (!getVerifyTarget())
+ return success();
+
auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
if (!gpuModuleOp)
return emitError(gpuModule->getLoc(),
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index bf5c349a9aa7b..97adeb000c55f 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -17,6 +17,16 @@ gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a
}
+gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+ test.nvvm_requires_sm_90a
+}
+
+gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget = false>] {
+ test.nvvm_requires_sm_80
+}
+
+
+
// -----
gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
>From 3aace32e07bbfd1428dc56fd3bc36f78e5748179 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 14:12:16 +0530
Subject: [PATCH 4/8] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 52 +++++++++----------
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 36 ++++++++-----
.../include/mlir/Dialect/LLVMIR/NVVMTraits.td | 8 ++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 ++-
.../Dialect/LLVMIR/nvvm-check-targetSM.mlir | 33 ++++++++++--
mlir/test/lib/Dialect/Test/TestOps.td | 10 +++-
6 files changed, 96 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d0060b24b7484..5ed79a5d422f2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -216,15 +216,15 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
//===----------------------------------------------------------------------===//
// CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
+def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
//===----------------------------------------------------------------------===//
@@ -325,7 +325,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
}
/// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
@@ -547,7 +547,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
let assemblyFormat = "attr-dict";
}
-def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
+def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>]> {
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
let summary = "Cluster Barrier Relaxed Arrive Op";
@@ -573,7 +573,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
let assemblyFormat = "attr-dict";
}
-def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
+def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> {
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
let summary = "Cluster Barrier Wait Op";
@@ -778,7 +778,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
def NVVM_ShflOp :
- NVVM_Op<"shfl.sync">,
+ NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins I32:$thread_mask,
LLVM_Type:$val,
@@ -2117,7 +2117,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
}];
}
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>]>,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
OptionalAttr<UnitAttr>:$read)> {
@@ -2147,7 +2147,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
- AttrSizedOperandSegments]>,
+ AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
Arguments<(ins LLVM_PointerShared:$dstMem,
LLVM_AnyPointer:$tmaDescriptor,
Variadic<I32>:$coordinates,
@@ -2584,8 +2584,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
- [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSM90a]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2599,8 +2598,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
- [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSM90a]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2612,7 +2610,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSM90a]> {
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
@@ -2808,7 +2806,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp
def NVVM_MapaOp: NVVM_Op<"mapa",
[TypesMatchWith<"`res` and `a` should have the same type",
- "a", "res", "$_self">]> {
+ "a", "res", "$_self">, NVVMRequiresSM<90>]> {
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
@@ -2975,7 +2973,7 @@ def Tcgen05WaitKindAttr :
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
+def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 alloc operation";
let description = [{
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -3005,7 +3003,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
}];
}
-def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
+def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 dealloc operation";
let description = [{
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -3033,7 +3031,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
}];
}
-def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
+def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 Op to relinquish the right to allocate";
let description = [{
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -3056,7 +3054,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];
}
-def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
+def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 fence operations";
let description = [{
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -3078,7 +3076,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
}];
}
-def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
+def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 wait operations";
let description = [{
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -3100,7 +3098,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
}];
}
-def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
+def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 commit operations";
let description = [{
The `tcgen05.commit` makes the mbarrier object, specified by
@@ -3138,7 +3136,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
}];
}
-def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -3204,7 +3202,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -3274,7 +3272,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
+def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
@@ -3364,7 +3362,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//
-def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
+def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSM<100, "true", "false">]> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 51b868cfe5a7e..04b408846e820 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -21,14 +21,21 @@ namespace mlir {
namespace NVVM {
+// Structure to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
int archVersion;
bool archAccelerated;
+ bool exactMatch;
- NVVMCheckSMVersion() {}
- NVVMCheckSMVersion(StringRef smVersion) { parse(smVersion); }
- NVVMCheckSMVersion(int archVersion, bool archAccelerated)
- : archVersion(archVersion), archAccelerated(archAccelerated) {}
+ NVVMCheckSMVersion()
+ : archVersion(0), archAccelerated(false), exactMatch(false) {}
+ NVVMCheckSMVersion(StringRef smVersion, bool exactMatch = false)
+ : exactMatch(exactMatch) {
+ parse(smVersion);
+ }
+ NVVMCheckSMVersion(int archVersion, bool archAccelerated, bool exactMatch)
+ : archVersion(archVersion), archAccelerated(archAccelerated),
+ exactMatch(exactMatch) {}
// Parses the SM version string and sets the archVersion (integer) and
// the archAccelerated flag.
@@ -40,11 +47,12 @@ struct NVVMCheckSMVersion {
}
bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
- // for arch-conditional SMs, they should exactly match to be valid
- if (archAccelerated || targetSM.archAccelerated)
+ if (exactMatch)
return (*this) == targetSM;
-
- return archVersion <= targetSM.archVersion;
+
+ return archAccelerated ?
+ archVersion <= targetSM.archVersion && targetSM.archAccelerated :
+ archVersion <= targetSM.archVersion;
}
bool operator==(const NVVMCheckSMVersion &other) const {
@@ -61,16 +69,18 @@ namespace mlir {
namespace OpTrait {
-template <int Version, bool ArchAccelerated = false>
+template <int MinVersion, bool ArchAccelerated = false, bool ExactMatch = false>
class NVVMRequiresSM {
public:
template <typename ConcreteOp>
- class Impl : public OpTrait::TraitBase<
- ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl>,
- public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
+ class Impl
+ : public OpTrait::TraitBase<
+ ConcreteOp,
+ NVVMRequiresSM<MinVersion, ArchAccelerated, ExactMatch>::Impl>,
+ public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
- return NVVM::NVVMCheckSMVersion(Version, ArchAccelerated);
+ return NVVM::NVVMCheckSMVersion(MinVersion, ArchAccelerated, ExactMatch);
}
};
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
index e8ec74eef2cb3..5e013b13dc26b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -27,8 +27,12 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
];
}
-class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+class NVVMRequiresSM<int minVersion, string isArchAccelerated = "false",
+ string exactMatch = "false"> :
ParamNativeOpTrait<"NVVMRequiresSM",
- !cast<string>(Version) # "," # ArchAccelerated>;
+ !cast<string>(minVersion) # "," # isArchAccelerated # ","
+ # exactMatch>;
+
+def NVVMRequiresSM90a : NVVMRequiresSM<90, "true", "true">;
#endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9b222eca4d5a1..d814ba2ae96b4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1713,8 +1713,12 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
if (!gpuModuleOp)
return emitError(gpuModule->getLoc(),
"NVVM target attribute must be attached to a GPU module");
-
+
NVVMCheckSMVersion targetSMVersion(getChip());
+ if (targetSMVersion.archVersion < 20)
+ return emitError(gpuModule->getLoc(),
+ "Minimum NVVM target SM version is sm_20");
+
gpuModuleOp->walk([&](Operation *op) {
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
index 97adeb000c55f..13a21d1e2156f 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -13,10 +13,19 @@ gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
test.nvvm_requires_sm_80
}
-gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+gpu.module @check_valid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a
}
+gpu.module @check_valid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_atleast_90_aa
+}
+
+gpu.module @check_valid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_100a">] {
+ test.nvvm_requires_sm_atleast_90_aa
+}
+
+
gpu.module @disable_verify_target1 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
test.nvvm_requires_sm_90a
}
@@ -25,7 +34,9 @@ gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget =
test.nvvm_requires_sm_80
}
-
+gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
+ test.nvvm_requires_sm_atleast_90_aa
+}
// -----
@@ -43,14 +54,28 @@ gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
// -----
-gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+gpu.module @check_invalid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90">] {
// expected-error @below {{is not supported on sm_90}}
test.nvvm_requires_sm_90a
}
// -----
-gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+gpu.module @check_invalid_SM_arch_acc_exact_2 [#nvvm.target<chip = "sm_80">] {
// expected-error @below {{is not supported on sm_80}}
test.nvvm_requires_sm_90a
}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_80">] {
+ // expected-error @below {{is not supported on sm_80}}
+ test.nvvm_requires_sm_atleast_90_aa
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_90">] {
+ // expected-error @below {{is not supported on sm_90}}
+ test.nvvm_requires_sm_atleast_90_aa
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bdf53a2eae047..19e5e6b179de1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2817,8 +2817,14 @@ def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
let assemblyFormat = "attr-dict";
}
-def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
- [NVVMRequiresSM<90, "true">]> {
+def TestNVVMRequiresAtleastSMArchCondOp :
+ TEST_Op<"nvvm_requires_sm_atleast_90_aa", [NVVMRequiresSM<90, "true">]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresExactSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+ [NVVMRequiresSM90a]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
>From afc3466e992c5f12c8af185d439d2d5edbdc7430 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 14:15:56 +0530
Subject: [PATCH 5/8] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 5ed79a5d422f2..894e9f55a1f4c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -172,14 +172,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
>From 56012877d5bffea6f20c558c5509daba9c53988d Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 14:25:34 +0530
Subject: [PATCH 6/8] fix formatting
---
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 8 ++++----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 04b408846e820..867d4d527e905 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -49,10 +49,10 @@ struct NVVMCheckSMVersion {
bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
if (exactMatch)
return (*this) == targetSM;
-
- return archAccelerated ?
- archVersion <= targetSM.archVersion && targetSM.archAccelerated :
- archVersion <= targetSM.archVersion;
+
+ return archAccelerated
+ ? archVersion <= targetSM.archVersion && targetSM.archAccelerated
+ : archVersion <= targetSM.archVersion;
}
bool operator==(const NVVMCheckSMVersion &other) const {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d814ba2ae96b4..33e0f71a93725 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1713,7 +1713,7 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
if (!gpuModuleOp)
return emitError(gpuModule->getLoc(),
"NVVM target attribute must be attached to a GPU module");
-
+
NVVMCheckSMVersion targetSMVersion(getChip());
if (targetSMVersion.archVersion < 20)
return emitError(gpuModule->getLoc(),
>From ad5110e8cd6da3a76868759c30cfd7f929e2b0ac Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 8 May 2025 10:46:41 +0530
Subject: [PATCH 7/8] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 9 ++++++++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 ++++--
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 867d4d527e905..1717bb952c93b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -24,8 +24,15 @@ namespace NVVM {
// Structure to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
int archVersion;
+
+ // true if the SM version is accelerated (Ex. sm_90a vs sm_90)
bool archAccelerated;
- bool exactMatch;
+
+ // true if the target SM version must exactly match this one
+ // (both archVersion and archAccelerated)
+ // Ex. sm_90a with exactMatch = false will also match with
+ // sm_100a, sm_120a, etc...
+ bool exactMatch;
NVVMCheckSMVersion()
: archVersion(0), archAccelerated(false), exactMatch(false) {}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 33e0f71a93725..3a77456054c83 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1710,14 +1710,16 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
return success();
auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
- if (!gpuModuleOp)
+ if (!gpuModuleOp) {
return emitError(gpuModule->getLoc(),
"NVVM target attribute must be attached to a GPU module");
+ }
NVVMCheckSMVersion targetSMVersion(getChip());
- if (targetSMVersion.archVersion < 20)
+ if (targetSMVersion.archVersion < 20) {
return emitError(gpuModule->getLoc(),
"Minimum NVVM target SM version is sm_20");
+ }
gpuModuleOp->walk([&](Operation *op) {
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
>From 3be83ece30b5e097541a928ffe251c8e34668f6c Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 8 May 2025 12:07:34 +0530
Subject: [PATCH 8/8] fix formatting
---
mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
index 1717bb952c93b..3773b1c85ebed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -24,15 +24,15 @@ namespace NVVM {
// Structure to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
int archVersion;
-
+
// true if the SM version is accelerated (Ex. sm_90a vs sm_90)
bool archAccelerated;
-
+
// true if the target SM version must exactly match this one
// (both archVersion and archAccelerated)
// Ex. sm_90a with exactMatch = false will also match with
// sm_100a, sm_120a, etc...
- bool exactMatch;
+ bool exactMatch;
NVVMCheckSMVersion()
: archVersion(0), archAccelerated(false), exactMatch(false) {}
More information about the Mlir-commits
mailing list