[Mlir-commits] [mlir] [NFC][MLIR][NVVM] Restructure NVVM dialect (PR #195811)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 5 02:09:46 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Srinivasa Ravi (Wolfram70)
<details>
<summary>Changes</summary>
Moves the declarations of the NVVM dialect and some widely used enums into separate files, making them easier to maintain and allowing them to be reused by the NVGPU dialect.
---
Full diff: https://github.com/llvm/llvm-project/pull/195811.diff
3 Files Affected:
- (added) mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td (+94)
- (added) mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td (+72)
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+7-136)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td
new file mode 100644
index 0000000000000..025e093ebd8b6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td
@@ -0,0 +1,94 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of the NVVM IR dialect.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVMIR_DIALECT
+#define NVVMIR_DIALECT
+
+include "mlir/IR/DialectBase.td"
+
+def NVVM_Dialect : Dialect {
+ let name = "nvvm";
+ let cppNamespace = "::mlir::NVVM";
+ let dependentDialects = ["LLVM::LLVMDialect"];
+ let hasOperationAttrVerify = 1;
+
+ let extraClassDeclaration = [{
+ /// Get the name of the attribute used to annotate external kernel
+ /// functions.
+ static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+ /// Get the name of the attribute used to annotate max threads required
+ /// per CTA for kernel functions.
+ static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
+ /// Get the name of the metadata names for each dimension
+ static StringRef getMaxntidXName() { return "maxntidx"; }
+ static StringRef getMaxntidYName() { return "maxntidy"; }
+ static StringRef getMaxntidZName() { return "maxntidz"; }
+
+ /// Get the name of the attribute used to annotate exact threads required
+ /// per CTA for kernel functions.
+ static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
+ /// Get the name of the metadata names for each dimension
+ static StringRef getReqntidXName() { return "reqntidx"; }
+ static StringRef getReqntidYName() { return "reqntidy"; }
+ static StringRef getReqntidZName() { return "reqntidz"; }
+
+ /// Get the name of the attribute used to annotate exact CTAs required
+ /// per cluster for kernel functions.
+ static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
+ /// Get the name of the metadata names for each dimension
+ static StringRef getClusterDimXName() { return "cluster_dim_x"; }
+ static StringRef getClusterDimYName() { return "cluster_dim_y"; }
+ static StringRef getClusterDimZName() { return "cluster_dim_z"; }
+
+ /// Get the name of the attribute used to annotate maximum number of
+ /// CTAs per cluster for kernel functions.
+ static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; }
+
+ /// Get the name of the attribute used to annotate min CTA required
+ /// per SM for kernel functions.
+ static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
+
+ /// Get the name of the attribute used to annotate max number of
+ /// registers that can be allocated per thread.
+ static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+ /// Get the name of the attribute used to annotate kernel arguments that
+ /// are grid constants.
+ static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+ /// Get the name of the attribute used to annotate the `.blocksareclusters`
+ /// PTX directive for kernel functions.
+ /// This attribute implies that the grid launch configuration for the
+ /// corresponding kernel function is specifying the number of clusters
+ /// instead of the number of thread blocks. This attribute is only
+ /// allowed for kernel functions and requires nvvm.reqntid and
+ /// nvvm.cluster_dim attributes.
+ static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
+
+ /// Get the name of the attribute used to annotate managed global variables.
+ static StringRef getManagedAttrName() { return "nvvm.managed"; }
+
+ /// Verify an attribute from this dialect on the argument at 'argIndex' for
+ /// the region at 'regionIndex' on the given operation. Returns failure if
+ /// the verification failed, success otherwise. This hook may optionally be
+ /// invoked from any operation containing a region.
+ LogicalResult verifyRegionArgAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute argAttr) override;
+ }];
+
+ let useDefaultAttributePrinterParser = 1;
+}
+
+#endif // NVVMIR_DIALECT
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td
new file mode 100644
index 0000000000000..42d196c5662d1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td
@@ -0,0 +1,72 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of the NVVM IR enum attributes.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVMIR_ENUMS
+#define NVVMIR_ENUMS
+
+include "mlir/Dialect/LLVMIR/NVVMDialect.td"
+include "mlir/IR/EnumAttr.td"
+
+// Attributes for the floating point rounding modes supported by PTX
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN : I32EnumAttrCase<"RN", 1, "rn">;
+def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
+def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
+def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
+def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
+def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">;
+
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+ [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+ FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def SaturationModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+def SaturationModeSat : I32EnumAttrCase<"SAT", 2, "sat">;
+
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+ [SaturationModeNone, SaturationModeFinite, SaturationModeSat]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+ let summary = "Describes the saturation mode";
+ let description = [{
+ A `nvvm.sat_mode` attribute specifies the saturation mode for instructions
+ involving floating points or integers. It can be one of the following
+ values:
+ - `none`: No saturation is applied.
+ - `satfinite`: If the absolute value of input (ignoring sign) is greater
+ than the `MAX_NORM` of the specified destination format, then the result
+ is the sign-preserved `MAX_NORM` of the destination format and a positive
+ `MAX_NORM` in unsigned datatypes for which the destination sign is not
+ supported. If the input is `NaN`, then the result can be `NaN` or the
+ `MAX_NORM` of the destination format, depending on the format.
+ - `sat`: For integer destination types, this limits the value to the range
+ `MININT..MAXINT` and applies to both signed and unsigned integer datatypes. For
+ floating point destination types (applies to only `F16`, `F32`, and `F64`
+ types), this limits the value to the range `[0.0, 1.0]` and flushes NaN
+ results to positive zero.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
+}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+#endif // NVVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 51ff22dfdc65c..0d271acd862ba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -13,17 +13,19 @@
#ifndef NVVMIR_OPS
#define NVVMIR_OPS
-include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMTypes.td"
+include "mlir/Dialect/LLVMIR/NVVMDialect.td"
+include "mlir/Dialect/LLVMIR/NVVMEnums.td"
include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Dialect/LLVMIR/LLVMTypes.td"
-include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -33,85 +35,6 @@ def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
-//===----------------------------------------------------------------------===//
-// NVVM dialect definitions
-//===----------------------------------------------------------------------===//
-
-def NVVM_Dialect : Dialect {
- let name = "nvvm";
- let cppNamespace = "::mlir::NVVM";
- let dependentDialects = ["LLVM::LLVMDialect"];
- let hasOperationAttrVerify = 1;
-
- let extraClassDeclaration = [{
- /// Get the name of the attribute used to annotate external kernel
- /// functions.
- static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
- /// Get the name of the attribute used to annotate max threads required
- /// per CTA for kernel functions.
- static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
- /// Get the name of the metadata names for each dimension
- static StringRef getMaxntidXName() { return "maxntidx"; }
- static StringRef getMaxntidYName() { return "maxntidy"; }
- static StringRef getMaxntidZName() { return "maxntidz"; }
-
- /// Get the name of the attribute used to annotate exact threads required
- /// per CTA for kernel functions.
- static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
- /// Get the name of the metadata names for each dimension
- static StringRef getReqntidXName() { return "reqntidx"; }
- static StringRef getReqntidYName() { return "reqntidy"; }
- static StringRef getReqntidZName() { return "reqntidz"; }
-
- /// Get the name of the attribute used to annotate exact CTAs required
- /// per cluster for kernel functions.
- static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
- /// Get the name of the metadata names for each dimension
- static StringRef getClusterDimXName() { return "cluster_dim_x"; }
- static StringRef getClusterDimYName() { return "cluster_dim_y"; }
- static StringRef getClusterDimZName() { return "cluster_dim_z"; }
-
- /// Get the name of the attribute used to annotate maximum number of
- /// CTAs per cluster for kernel functions.
- static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; }
-
- /// Get the name of the attribute used to annotate min CTA required
- /// per SM for kernel functions.
- static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
-
- /// Get the name of the attribute used to annotate max number of
- /// registers that can be allocated per thread.
- static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
-
- /// Get the name of the attribute used to annotate kernel arguments that
- /// are grid constants.
- static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
-
- /// Get the name of the attribute used to annotate the `.blocksareclusters`
- /// PTX directive for kernel functions.
- /// This attribute implies that the grid launch configuration for the
- /// corresponding kernel function is specifying the number of clusters
- /// instead of the number of thread blocks. This attribute is only
- /// allowed for kernel functions and requires nvvm.reqntid and
- /// nvvm.cluster_dim attributes.
- static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
-
- /// Get the name of the attribute used to annotate managed global variables.
- static StringRef getManagedAttrName() { return "nvvm.managed"; }
-
- /// Verify an attribute from this dialect on the argument at 'argIndex' for
- /// the region at 'regionIndex' on the given operation. Returns failure if
- /// the verification failed, success otherwise. This hook may optionally be
- /// invoked from any operation containing a region.
- LogicalResult verifyRegionArgAttribute(Operation *op,
- unsigned regionIndex,
- unsigned argIndex,
- NamedAttribute argAttr) override;
- }];
-
- let useDefaultAttributePrinterParser = 1;
-}
-
//===----------------------------------------------------------------------===//
// NVVM op definitions
//===----------------------------------------------------------------------===//
@@ -1917,58 +1840,6 @@ def NVVM_CpAsyncMBarrierArriveOp : NVVM_VoidIntrinsicOp<"cp.async.mbarrier.arriv
// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
//===----------------------------------------------------------------------===//
-// Attributes for the floating point rounding modes supported by PTX
-def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
-def FPRoundingModeRN : I32EnumAttrCase<"RN", 1, "rn">;
-def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
-def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
-def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
-def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
-def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">;
-
-def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
- [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
- FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def SaturationModeNone : I32EnumAttrCase<"NONE", 0, "none">;
-def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
-def SaturationModeSat : I32EnumAttrCase<"SAT", 2, "sat">;
-
-def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
- [SaturationModeNone, SaturationModeFinite, SaturationModeSat]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
- let summary = "Describes the saturation mode";
- let description = [{
- A `nvvm.sat_mode` attribute specifies the saturation mode for instructions
- involving floating points or integers. It can be one of the following
- values:
- - `none`: No saturation is applied.
- - `satfinite`: If the absolute value of input (ignoring sign) is greater
- than the `MAX_NORM` of the specified destination format, then the result
- is the sign-preserved `MAX_NORM` of the destination format and a positive
- `MAX_NORM` in unsigned datatypes for which the destination sign is not
- supported. If the input is `NaN`, then the result can be `NaN` or th
- `MAX_NORM` of the destination format, depending on the format.
- - `sat`: For integer destination types, this limits the value to `MININT..
- MAXINT` and applies to both signed and unsigned integer datatypes. For
- floating point destination types (applies to only `F16`, `F32`, and `F64`
- types), this limits the value to the range `[0.0, 1.0]` and flushes NaN
- results to positive zero.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
-}];
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
let summary = "Convert the given float input to TF32";
let description = [{
``````````
</details>
https://github.com/llvm/llvm-project/pull/195811
More information about the Mlir-commits mailing list