[Mlir-commits] [mlir] [mlir][NVVM] Tighten result-type predicate on special-register ops (PR #195030)

Bastian Hagedorn llvmlistbot at llvm.org
Thu Apr 30 03:17:36 PDT 2026


https://github.com/bastianhagedorn updated https://github.com/llvm/llvm-project/pull/195030

From 5d22055b882ae0edbfc7e8a93171254e0a56de0e Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Thu, 30 Apr 2026 09:34:25 +0000
Subject: [PATCH 1/2] [mlir][NVVM] Tighten result-type predicate on
 special-register ops

Use concrete `I32` (default) and `I64` (clock64, globaltimer) instead
of generic `LLVM_Type` for the special-register op result type.
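The effect of the tightened predicate can be sketched in plain Python. This is a hypothetical model, not MLIR's ODS-generated verifier: `WIDE_REGISTERS`, `expected_width` and `verify_sreg_result` are invented names, and the table covers only the registers this patch moves to `I64`. Every other special-register op falls back to the 32-bit default, and a mismatch yields a message shaped like the diagnostics checked in invalid.mlir.

```python
# Hypothetical model of the result-type predicate; NOT MLIR's generated verifier.
from typing import Optional

# Mirrors the `Type resultType = I32` default on NVVM_SpecialRegisterOp.
DEFAULT_WIDTH = 32

# Registers that this patch overrides with I64 (assumption: only those in the diff).
WIDE_REGISTERS = {
    "nvvm.read.ptx.sreg.clock64": 64,
    "nvvm.read.ptx.sreg.globaltimer": 64,
}


def expected_width(op_name: str) -> int:
    """Return the integer bit width the op's single result is constrained to."""
    return WIDE_REGISTERS.get(op_name, DEFAULT_WIDTH)


def verify_sreg_result(op_name: str, result_type: str) -> Optional[str]:
    """Return None if `result_type` is accepted, else a diagnostic string
    shaped like the expected-error messages in invalid.mlir."""
    width = expected_width(op_name)
    if result_type == f"i{width}":  # only the signless form "iN" is accepted
        return None
    return (f"'{op_name}' op result #0 must be {width}-bit signless integer, "
            f"but got '{result_type}'")


if __name__ == "__main__":
    print(verify_sreg_result("nvvm.read.ptx.sreg.tid.x", "i64"))
    print(verify_sreg_result("nvvm.read.ptx.sreg.clock64", "i32"))
    print(verify_sreg_result("nvvm.read.ptx.sreg.globaltimer.lo", "i32"))
```

Note that `globaltimer.lo` keeps the 32-bit default in the diff, so the model accepts `i32` for it while rejecting `i32` for `globaltimer`.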

Assisted-by: Claude Opus 4.7 (Anthropic)
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td |  9 ++++++---
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 16 ++++++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 95fe1e0535843..73afdb29b6149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -309,12 +309,15 @@ class NVVM_SingleResultIntrinsicOp<string mnemonic, list<Trait> traits = [], str
 class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
   let arguments = (ins);
+  let results = (outs I32:$res);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+class NVVM_SpecialRegisterOp<string mnemonic, Type resultType = I32,
+                                              list<Trait> traits = []> :
   NVVM_IntrOp<mnemonic, traits, 1> {
   let arguments = (ins);
+  let results = (outs resultType:$res);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -421,8 +424,8 @@ def NVVM_AggrSmemSize    : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.aggr.smem.s
 //===----------------------------------------------------------------------===//
 // Clock registers
 def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
-def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
-def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
+def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64", I64>;
+def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer", I64>;
 def NVVM_GlobalTimerLoOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer.lo">;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e849b59b846f7..0e3357992be18 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -2115,3 +2115,19 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     %0 = llvm.ptrtoaddr %arg0 : !llvm.ptr to i64
   }
 }
+
+// -----
+
+func.func @nvvm_read_sreg_tid_x_wrong_type() {
+  // expected-error@+1 {{'nvvm.read.ptx.sreg.tid.x' op result #0 must be 32-bit signless integer, but got 'i64'}}
+  %0 = nvvm.read.ptx.sreg.tid.x : i64
+  return
+}
+
+// -----
+
+func.func @nvvm_read_sreg_clock64_wrong_type() {
+  // expected-error@+1 {{'nvvm.read.ptx.sreg.clock64' op result #0 must be 64-bit signless integer, but got 'i32'}}
+  %0 = nvvm.read.ptx.sreg.clock64 : i32
+  return
+}

From 8be23f8821072af9ca3262079f4d07c4b55f11da Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Thu, 30 Apr 2026 09:34:26 +0000
Subject: [PATCH 2/2] [mlir][NVVM][python] Test inferred-result form for
 special-register ops

Construct a few special-register ops from Python without passing an
explicit result type, and check that the printed types are the inferred
ones (i32, or i64 for clock64 and globaltimer).
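The property this test relies on can be modeled in a few lines of plain Python. This is a hypothetical sketch, not the `mlir.dialects.nvvm` bindings: `FIXED_RESULT_TYPE` and `build_sreg` are invented names, and the table lists only the four ops the test constructs. Because each op now admits a single result type, a builder can fill that type in when the caller omits it, and the printed form follows the ops' assembly format (attributes, a colon, then the result type).

```python
# Hypothetical stand-in for an op builder; NOT the mlir.dialects.nvvm API.
from typing import Optional

# After the first patch, each of these ops admits exactly one result type.
FIXED_RESULT_TYPE = {
    "nvvm.read.ptx.sreg.tid.x": "i32",
    "nvvm.read.ptx.sreg.clock": "i32",
    "nvvm.read.ptx.sreg.clock64": "i64",
    "nvvm.read.ptx.sreg.globaltimer": "i64",
}


def build_sreg(op_name: str, result_type: Optional[str] = None, ssa_id: int = 0) -> str:
    """Return the op's textual form, inferring the result type when omitted."""
    fixed = FIXED_RESULT_TYPE[op_name]
    if result_type is None:
        # Only one type is legal, so the caller does not need to supply it.
        result_type = fixed
    elif result_type != fixed:
        raise ValueError(f"{op_name} must produce {fixed}, got {result_type}")
    # Textual form: `%N = <op name> : <result type>` (no attributes here).
    return f"%{ssa_id} = {op_name} : {result_type}"


if __name__ == "__main__":
    for i, name in enumerate(FIXED_RESULT_TYPE):
        print(build_sreg(name, ssa_id=i))
```

Passing an explicit type is still allowed in this model, but only the single legal type is accepted, which matches the intent of the invalid.mlir tests in the first patch.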

Assisted-by: Claude Opus 4.7 (Anthropic)
---
 mlir/test/python/dialects/nvvm.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 24abf617548b8..f5e057812642e 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -377,3 +377,16 @@ def reductions(mask, vi32, vf32):
 # CHECK:           %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
 # CHECK:           return
 # CHECK:         }
+
+
+# CHECK-LABEL: TEST: testSpecialRegisterInferredResults
+@constructAndPrintInModule
+def testSpecialRegisterInferredResults():
+    # CHECK: %{{.*}} = nvvm.read.ptx.sreg.tid.x : i32
+    nvvm.ThreadIdXOp()
+    # CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32
+    nvvm.ClockOp()
+    # CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock64 : i64
+    nvvm.Clock64Op()
+    # CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64
+    nvvm.GlobalTimerOp()
