[Mlir-commits] [mlir] [MLIR][NVVM] Add PTX predefined special registers (PR #112343)

Pradeep Kumar llvmlistbot at llvm.org
Tue Oct 15 03:09:13 PDT 2024


https://github.com/schwarzschild-radius created https://github.com/llvm/llvm-project/pull/112343

This commit adds support for the following PTX predefined special registers:
* warpid
* nwarpid
* smid
* nsmid
* gridid
* lanemask.{eq,le,lt,ge,gt}
* globaltimer
* envreg0 through envreg31

It also adds lit tests for these registers to nvvmir.mlir.

>From bcab57ab18cbf3b85c86c865a6c79362704b5419 Mon Sep 17 00:00:00 2001
From: pradeepku <pradeepku at nvidia.com>
Date: Thu, 10 Oct 2024 23:11:49 +0530
Subject: [PATCH] [MLIR][NVVM] Add PTX predefined special registers

This commit adds support for the following PTX predefined special
registers:
* warpid
* nwarpid
* smid
* nsmid
* gridid
* lanemask.{eq,le,lt,ge,gt}
* globaltimer
* envreg0 through envreg31

It also adds lit tests for these registers to nvvmir.mlir.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 22 ++++-
 mlir/test/Target/LLVMIR/nvvmir.mlir         | 92 ++++++++++++++++++++-
 2 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 152715f281088e..e67f5fc8f9347b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -139,9 +139,22 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
 }
 
 //===----------------------------------------------------------------------===//
-// Lane index and range
+// Lane, warp, SM and grid indices and counts (optionally range-annotated)
 def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
 def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
+def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
+def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
+
+//===----------------------------------------------------------------------===//
+// Lane mask registers (bit masks of warp lanes, so no `range` attribute)
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
+def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
+def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
+def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
 
 //===----------------------------------------------------------------------===//
 // Thread index and range
@@ -189,6 +202,13 @@ def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nct
 // Clock registers
 def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
 def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
+def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
+
+//===----------------------------------------------------------------------===//
+// Environment registers envreg0 .. envreg31, one op per register
+foreach index = !range(0, 32) in {
+  def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
+}
 
 //===----------------------------------------------------------------------===//
 // NVVM approximate op definitions
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 7fd082a5eb3c75..0471e5faf84578 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -62,10 +62,94 @@ llvm.func @nvvm_special_regs() -> i32 {
   %29 = nvvm.read.ptx.sreg.clock : i32
   // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
   %30 = nvvm.read.ptx.sreg.clock64 : i64
-
-  // CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  %31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
-
+  // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.globaltimer()
+  %31 = nvvm.read.ptx.sreg.globaltimer : i64
+  // CHECK: %32 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %32 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.warpid()
+  %33 = nvvm.read.ptx.sreg.warpid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nwarpid()
+  %34 = nvvm.read.ptx.sreg.nwarpid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.smid()
+  %35 = nvvm.read.ptx.sreg.smid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nsmid()
+  %36 = nvvm.read.ptx.sreg.nsmid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.gridid()
+  %37 = nvvm.read.ptx.sreg.gridid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg0()
+  %38 = nvvm.read.ptx.sreg.envreg0 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg1()
+  %39 = nvvm.read.ptx.sreg.envreg1 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg2()
+  %40 = nvvm.read.ptx.sreg.envreg2 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg3()
+  %41 = nvvm.read.ptx.sreg.envreg3 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg4()
+  %42 = nvvm.read.ptx.sreg.envreg4 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg5()
+  %43 = nvvm.read.ptx.sreg.envreg5 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg6()
+  %44 = nvvm.read.ptx.sreg.envreg6 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg7()
+  %45 = nvvm.read.ptx.sreg.envreg7 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg8()
+  %46 = nvvm.read.ptx.sreg.envreg8 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg9()
+  %47 = nvvm.read.ptx.sreg.envreg9 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg10()
+  %48 = nvvm.read.ptx.sreg.envreg10 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg11()
+  %49 = nvvm.read.ptx.sreg.envreg11 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg12()
+  %50 = nvvm.read.ptx.sreg.envreg12 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg13()
+  %51 = nvvm.read.ptx.sreg.envreg13 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg14()
+  %52 = nvvm.read.ptx.sreg.envreg14 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg15()
+  %53 = nvvm.read.ptx.sreg.envreg15 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg16()
+  %54 = nvvm.read.ptx.sreg.envreg16 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg17()
+  %55 = nvvm.read.ptx.sreg.envreg17 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg18()
+  %56 = nvvm.read.ptx.sreg.envreg18 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg19()
+  %57 = nvvm.read.ptx.sreg.envreg19 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg20()
+  %58 = nvvm.read.ptx.sreg.envreg20 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg21()
+  %59 = nvvm.read.ptx.sreg.envreg21 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg22()
+  %60 = nvvm.read.ptx.sreg.envreg22 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg23()
+  %61 = nvvm.read.ptx.sreg.envreg23 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg24()
+  %62 = nvvm.read.ptx.sreg.envreg24 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg25()
+  %63 = nvvm.read.ptx.sreg.envreg25 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg26()
+  %64 = nvvm.read.ptx.sreg.envreg26 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg27()
+  %65 = nvvm.read.ptx.sreg.envreg27 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg28()
+  %66 = nvvm.read.ptx.sreg.envreg28 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg29()
+  %67 = nvvm.read.ptx.sreg.envreg29 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg30()
+  %68 = nvvm.read.ptx.sreg.envreg30 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg31()
+  %69 = nvvm.read.ptx.sreg.envreg31 : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.eq()
+  %70 = nvvm.read.ptx.sreg.lanemask.eq : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.le()
+  %71 = nvvm.read.ptx.sreg.lanemask.le : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.lt()
+  %72 = nvvm.read.ptx.sreg.lanemask.lt : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.ge()
+  %73 = nvvm.read.ptx.sreg.lanemask.ge : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt()
+  %74 = nvvm.read.ptx.sreg.lanemask.gt : i32
   llvm.return %1 : i32
 }
 



More information about the Mlir-commits mailing list