[Mlir-commits] [mlir] 574b423 - [MLIR][NVVM] Introduce special registers for CTA Cluster
Guray Ozen
llvmlistbot at llvm.org
Fri Sep 1 07:38:35 PDT 2023
Author: Guray Ozen
Date: 2023-09-01T16:38:29+02:00
New Revision: 574b423a80acd252dcf031f5a5ea0d6ec76bf0de
URL: https://github.com/llvm/llvm-project/commit/574b423a80acd252dcf031f5a5ea0d6ec76bf0de
DIFF: https://github.com/llvm/llvm-project/commit/574b423a80acd252dcf031f5a5ea0d6ec76bf0de.diff
LOG: [MLIR][NVVM] Introduce special registers for CTA Cluster
This work introduces special registers such as cluster ID, dimensions, and more for managing CTA clusters, which are groups of CTAs that can synchronize and communicate through shared memory. This is for Nvidia's sm_90 capability.
Differential Revision: https://reviews.llvm.org/D158588
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6daeb93eb4cd3c..55aa820314e89b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -249,6 +249,30 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+//===----------------------------------------------------------------------===//
+// CTA Cluster index and range
+def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
+def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
+def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
+def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
+def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
+
+
+//===----------------------------------------------------------------------===//
+// CTA index and range within Cluster
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
+def NVVM_GridInClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
+def NVVM_GridInClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
+def NVVM_GridInClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+
+//===----------------------------------------------------------------------===//
+// CTA index and range across Cluster dimensions
+def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
+def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+
//===----------------------------------------------------------------------===//
// NVVM approximate op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 39af5895387d32..24ef1198577937 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -30,6 +30,35 @@ llvm.func @nvvm_special_regs() -> i32 {
%13 = nvvm.read.ptx.sreg.warpsize : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.laneid()
%14 = nvvm.read.ptx.sreg.laneid : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.x
+ %15 = nvvm.read.ptx.sreg.clusterid.x : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.y
+ %16 = nvvm.read.ptx.sreg.clusterid.y : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.z
+ %17 = nvvm.read.ptx.sreg.clusterid.z : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.x
+ %18 = nvvm.read.ptx.sreg.nclusterid.x : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.y
+ %19 = nvvm.read.ptx.sreg.nclusterid.y : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.z
+ %20 = nvvm.read.ptx.sreg.nclusterid.z : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
+ %21 = nvvm.read.ptx.sreg.cluster.ctaid.x : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
+ %22 = nvvm.read.ptx.sreg.cluster.ctaid.y : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
+ %23 = nvvm.read.ptx.sreg.cluster.ctaid.z : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
+ %24 = nvvm.read.ptx.sreg.cluster.nctaid.x : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
+ %25 = nvvm.read.ptx.sreg.cluster.nctaid.y : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
+ %26 = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank
+ %27 = nvvm.read.ptx.sreg.cluster.ctarank : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank
+ %28 = nvvm.read.ptx.sreg.cluster.nctarank : i32
+
llvm.return %1 : i32
}
More information about the Mlir-commits
mailing list