[Mlir-commits] [mlir] [NFC][MLIR][NVVM] Restructure NVVM dialect (PR #195811)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 5 02:09:46 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

Moves the NVVM dialect declaration and some widely used enums into separate files, making them easier to maintain and to reuse from the NVGPU dialect.

---
Full diff: https://github.com/llvm/llvm-project/pull/195811.diff


3 Files Affected:

- (added) mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td (+94) 
- (added) mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td (+72) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+7-136) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td
new file mode 100644
index 0000000000000..025e093ebd8b6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.td
@@ -0,0 +1,94 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of the NVVM IR dialect.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVMIR_DIALECT
+#define NVVMIR_DIALECT
+
+include "mlir/IR/DialectBase.td"
+
+def NVVM_Dialect : Dialect {
+  let name = "nvvm";
+  let cppNamespace = "::mlir::NVVM";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let hasOperationAttrVerify = 1;
+
+  let extraClassDeclaration = [{
+    /// Get the name of the attribute used to annotate external kernel
+    /// functions.
+    static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+    /// Get the name of the attribute used to annotate max threads required
+    /// per CTA for kernel functions.
+    static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
+    /// Get the metadata names for each dimension.
+    static StringRef getMaxntidXName() { return "maxntidx"; }
+    static StringRef getMaxntidYName() { return "maxntidy"; }
+    static StringRef getMaxntidZName() { return "maxntidz"; }
+
+    /// Get the name of the attribute used to annotate exact threads required
+    /// per CTA for kernel functions.
+    static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
+    /// Get the metadata names for each dimension.
+    static StringRef getReqntidXName() { return "reqntidx"; }
+    static StringRef getReqntidYName() { return "reqntidy"; }
+    static StringRef getReqntidZName() { return "reqntidz"; }
+
+    /// Get the name of the attribute used to annotate exact CTAs required
+    /// per cluster for kernel functions.
+    static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
+    /// Get the metadata names for each dimension.
+    static StringRef getClusterDimXName() { return "cluster_dim_x"; }
+    static StringRef getClusterDimYName() { return "cluster_dim_y"; }
+    static StringRef getClusterDimZName() { return "cluster_dim_z"; }
+
+    /// Get the name of the attribute used to annotate maximum number of
+    /// CTAs per cluster for kernel functions.
+    static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; }
+
+    /// Get the name of the attribute used to annotate min CTA required
+    /// per SM for kernel functions.
+    static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
+
+    /// Get the name of the attribute used to annotate max number of
+    /// registers that can be allocated per thread.
+    static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+    /// Get the name of the attribute used to annotate kernel arguments that
+    /// are grid constants.
+    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+    /// Get the name of the attribute used to annotate the `.blocksareclusters`
+    /// PTX directive for kernel functions.
+    /// This attribute implies that the grid launch configuration for the
+    /// corresponding kernel function is specifying the number of clusters
+    /// instead of the number of thread blocks. This attribute is only
+    /// allowed for kernel functions and requires nvvm.reqntid and
+    /// nvvm.cluster_dim attributes.
+    static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
+
+    /// Get the name of the attribute used to annotate managed global variables.
+    static StringRef getManagedAttrName() { return "nvvm.managed"; }
+
+    /// Verify an attribute from this dialect on the argument at 'argIndex' for
+    /// the region at 'regionIndex' on the given operation. Returns failure if
+    /// the verification failed, success otherwise. This hook may optionally be
+    /// invoked from any operation containing a region.
+    LogicalResult verifyRegionArgAttribute(Operation *op,
+                                           unsigned regionIndex,
+                                           unsigned argIndex,
+                                           NamedAttribute argAttr) override;
+  }];
+
+  let useDefaultAttributePrinterParser = 1;
+}
+
+#endif // NVVMIR_DIALECT
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td
new file mode 100644
index 0000000000000..42d196c5662d1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMEnums.td
@@ -0,0 +1,72 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of the NVVM IR enum attributes.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVMIR_ENUMS
+#define NVVMIR_ENUMS
+
+include "mlir/Dialect/LLVMIR/NVVMDialect.td"
+include "mlir/IR/EnumAttr.td"
+
+// Enum cases for the floating-point rounding modes supported by PTX.
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
+def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
+def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
+def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
+def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;
+def FPRoundingModeRS   : I32EnumAttrCase<"RS",   6, "rs">;
+
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+    FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+def SaturationModeSat    : I32EnumAttrCase<"SAT", 2, "sat">;
+
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+  [SaturationModeNone, SaturationModeFinite, SaturationModeSat]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+  let summary = "Describes the saturation mode";
+  let description = [{
+    An `nvvm.sat_mode` attribute specifies the saturation mode for instructions
+    involving floating-point or integer values. It can be one of the following
+    values:
+    - `none`: No saturation is applied.
+    - `satfinite`: If the absolute value of input (ignoring sign) is greater
+      than the `MAX_NORM` of the specified destination format, then the result
+      is the sign-preserved `MAX_NORM` of the destination format and a positive
+      `MAX_NORM` in unsigned datatypes for which the destination sign is not
+      supported. If the input is `NaN`, then the result can be `NaN` or the
+      `MAX_NORM` of the destination format, depending on the format.
+    - `sat`: For integer destination types, this limits the value to
+      `MININT..MAXINT` and applies to both signed and unsigned integer
+      datatypes. For floating-point destination types (applies only to `F16`,
+      `F32`, and `F64` types), this limits the value to the range `[0.0, 1.0]`
+      and flushes NaN results to positive zero.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
+  }];
+  let assemblyFormat = "`<` $value `>`";
+}
+
+#endif // NVVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 51ff22dfdc65c..0d271acd862ba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -13,17 +13,19 @@
 #ifndef NVVMIR_OPS
 #define NVVMIR_OPS
 
-include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMTypes.td"
+include "mlir/Dialect/LLVMIR/NVVMDialect.td"
+include "mlir/Dialect/LLVMIR/NVVMEnums.td"
 include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
 include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Dialect/LLVMIR/LLVMTypes.td"
-include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -33,85 +35,6 @@ def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
 def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
 def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
 
-//===----------------------------------------------------------------------===//
-// NVVM dialect definitions
-//===----------------------------------------------------------------------===//
-
-def NVVM_Dialect : Dialect {
-  let name = "nvvm";
-  let cppNamespace = "::mlir::NVVM";
-  let dependentDialects = ["LLVM::LLVMDialect"];
-  let hasOperationAttrVerify = 1;
-
-  let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
-    /// Get the name of the attribute used to annotate max threads required
-    /// per CTA for kernel functions.
-    static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
-    /// Get the name of the metadata names for each dimension
-    static StringRef getMaxntidXName() { return "maxntidx"; }
-    static StringRef getMaxntidYName() { return "maxntidy"; }
-    static StringRef getMaxntidZName() { return "maxntidz"; }
-
-    /// Get the name of the attribute used to annotate exact threads required
-    /// per CTA for kernel functions.
-    static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
-    /// Get the name of the metadata names for each dimension
-    static StringRef getReqntidXName() { return "reqntidx"; }
-    static StringRef getReqntidYName() { return "reqntidy"; }
-    static StringRef getReqntidZName() { return "reqntidz"; }
-
-    /// Get the name of the attribute used to annotate exact CTAs required
-    /// per cluster for kernel functions.
-    static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
-    /// Get the name of the metadata names for each dimension
-    static StringRef getClusterDimXName() { return "cluster_dim_x"; }
-    static StringRef getClusterDimYName() { return "cluster_dim_y"; }
-    static StringRef getClusterDimZName() { return "cluster_dim_z"; }
-
-    /// Get the name of the attribute used to annotate maximum number of
-    /// CTAs per cluster for kernel functions.
-    static StringRef getClusterMaxBlocksAttrName() {  return "nvvm.cluster_max_blocks"; }
-
-    /// Get the name of the attribute used to annotate min CTA required
-    /// per SM for kernel functions.
-    static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
-
-    /// Get the name of the attribute used to annotate max number of
-    /// registers that can be allocated per thread.
-    static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
-
-    /// Get the name of the attribute used to annotate kernel arguments that
-    /// are grid constants.
-    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
-
-    /// Get the name of the attribute used to annotate the `.blocksareclusters`
-    /// PTX directive for kernel functions.
-    /// This attribute implies that the grid launch configuration for the
-    /// corresponding kernel function is specifying the number of clusters
-    /// instead of the number of thread blocks. This attribute is only
-    /// allowed for kernel functions and requires nvvm.reqntid and
-    /// nvvm.cluster_dim attributes.
-    static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
-
-    /// Get the name of the attribute used to annotate managed global variables.
-    static StringRef getManagedAttrName() { return "nvvm.managed"; }
-
-    /// Verify an attribute from this dialect on the argument at 'argIndex' for
-    /// the region at 'regionIndex' on the given operation. Returns failure if
-    /// the verification failed, success otherwise. This hook may optionally be
-    /// invoked from any operation containing a region.
-    LogicalResult verifyRegionArgAttribute(Operation *op,
-                                           unsigned regionIndex,
-                                           unsigned argIndex,
-                                           NamedAttribute argAttr) override;
-  }];
-
-  let useDefaultAttributePrinterParser = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // NVVM op definitions
 //===----------------------------------------------------------------------===//
@@ -1917,58 +1840,6 @@ def NVVM_CpAsyncMBarrierArriveOp : NVVM_VoidIntrinsicOp<"cp.async.mbarrier.arriv
 // NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
 //===----------------------------------------------------------------------===//
 
-// Attributes for the floating point rounding modes supported by PTX
-def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
-def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
-def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
-def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
-def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
-def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;
-def FPRoundingModeRS   : I32EnumAttrCase<"RS",   6, "rs">;
-
-def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
-  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
-    FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
-def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
-def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
-def SaturationModeSat    : I32EnumAttrCase<"SAT", 2, "sat">;
-
-def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
-  [SaturationModeNone, SaturationModeFinite, SaturationModeSat]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
-  let summary = "Describes the saturation mode";
-  let description = [{
-    A `nvvm.sat_mode` attribute specifies the saturation mode for instructions 
-    involving floating points or integers. It can be one of the following 
-    values:
-    - `none`: No saturation is applied.
-    - `satfinite`: If the absolute value of input (ignoring sign) is greater 
-      than the `MAX_NORM` of the specified destination format, then the result 
-      is the sign-preserved `MAX_NORM` of the destination format and a positive 
-      `MAX_NORM` in unsigned datatypes for which the destination sign is not 
-      supported. If the input is `NaN`, then the result can be `NaN` or th 
-      `MAX_NORM` of the destination format, depending on the format.
-    - `sat`: For integer destination types, this limits the value to `MININT..
-      MAXINT` and applies to both signed and unsigned integer datatypes. For 
-      floating point destination types (applies to only `F16`, `F32`, and `F64` 
-      types), this limits the value to the range `[0.0, 1.0]` and flushes NaN 
-      results to positive zero.
-
-    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
-}];
- let assemblyFormat = "`<` $value `>`";
-}
-
 def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   let summary = "Convert the given float input to TF32";
   let description = [{

``````````

</details>


https://github.com/llvm/llvm-project/pull/195811


More information about the Mlir-commits mailing list