[Mlir-commits] [mlir] [mlir][nvvm] Add attributes for cluster dimension PTX directives (PR #116973)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 20 05:53:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (arthurqiu)

<details>
<summary>Changes</summary>

The PTX programming model provides cluster dimension directives, which are used by the downstream `ptxas` compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/#supported-properties and https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives

This PR introduces the cluster dimension directives to MLIR's NVVM dialect as listed below:
```
cluster_dim_{x,y,z}    ->    exact number of CTAs per cluster
cluster_max_blocks     ->    max number of CTAs per cluster
```

---
Full diff: https://github.com/llvm/llvm-project/pull/116973.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+12) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+8-4) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+14) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+22) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6b462de144d1ff..296a3c305e5bf4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -53,6 +53,18 @@ def NVVM_Dialect : Dialect {
     static StringRef getReqntidYName() { return "reqntidy"; }
     static StringRef getReqntidZName() { return "reqntidz"; }
 
+    /// Get the name of the attribute used to annotate exact CTAs required
+    /// per cluster for kernel functions.
+    static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
+    /// Get the name of the metadata names for each dimension
+    static StringRef getClusterDimXName() { return "cluster_dim_x"; }
+    static StringRef getClusterDimYName() { return "cluster_dim_y"; }
+    static StringRef getClusterDimZName() { return "cluster_dim_z"; }
+
+    /// Get the name of the attribute used to annotate maximum number of
+    /// CTAs per cluster for kernel functions.
+    static StringRef getClusterMaxBlocksAttrName() {  return "nvvm.cluster_max_blocks"; }
+
     /// Get the name of the attribute used to annotate min CTA required
     /// per SM for kernel functions.
     static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d28194d5c00298..ca04af0b060b4f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1126,18 +1126,22 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                              << "' attribute attached to unexpected op";
     }
   }
-  // If maxntid and reqntid exist, it must be an array with max 3 dim
+  // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
+  // dim
   if (attrName == NVVMDialect::getMaxntidAttrName() ||
-      attrName == NVVMDialect::getReqntidAttrName()) {
+      attrName == NVVMDialect::getReqntidAttrName() ||
+      attrName == NVVMDialect::getClusterDimAttrName()) {
     auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
     if (!values || values.empty() || values.size() > 3)
       return op->emitError()
              << "'" << attrName
              << "' attribute must be integer array with maximum 3 index";
   }
-  // If minctasm and maxnreg exist, it must be an integer attribute
+  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
+  // attribute
   if (attrName == NVVMDialect::getMinctasmAttrName() ||
-      attrName == NVVMDialect::getMaxnregAttrName()) {
+      attrName == NVVMDialect::getMaxnregAttrName() ||
+      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
     if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
       return op->emitError()
              << "'" << attrName << "' attribute must be integer constant";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9cc66207660f64..cf58bc5d8f475a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -214,6 +214,20 @@ class NVVMDialectLLVMIRTranslationInterface
         generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
       if (values.size() > 2)
         generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getClusterDimAttrName()) {
+      if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
+        return failure();
+      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
+      generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
+      if (values.size() > 1)
+        generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
+      if (values.size() > 2)
+        generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
+      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
+      generateMetadata(value.getInt(), "cluster_max_blocks");
     } else if (attribute.getName() ==
                NVVM::NVVMDialect::getMinctasmAttrName()) {
       auto value = dyn_cast<IntegerAttr>(attribute.getValue());
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index e5ea03ff7e0017..a4a3581d6b7594 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -586,6 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2
 // CHECK:     {ptr @kernel_func, !"reqntidz", i32 32}
 // -----
 
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32: 3, 5, 7>} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"cluster_dim_x", i32 3}
+// CHECK:     {ptr @kernel_func, !"cluster_dim_y", i32 5}
+// CHECK:     {ptr @kernel_func, !"cluster_dim_z", i32 7}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"cluster_max_blocks", i32 8}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// -----
+
 llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
   llvm.return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/116973


More information about the Mlir-commits mailing list