[Mlir-commits] [mlir] 5cd56c9 - [MLIR][NVVM] Remove Pure trait from clock, clock64, globaltimer Ops (#147608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 11 03:05:51 PDT 2025
Author: Pradeep Kumar
Date: 2025-07-11T15:35:46+05:30
New Revision: 5cd56c9216b2c524423a5d913fc8458d28e25d1a
URL: https://github.com/llvm/llvm-project/commit/5cd56c9216b2c524423a5d913fc8458d28e25d1a
DIFF: https://github.com/llvm/llvm-project/commit/5cd56c9216b2c524423a5d913fc8458d28e25d1a.diff
LOG: [MLIR][NVVM] Remove Pure trait from clock, clock64, globaltimer Ops (#147608)
This commit removes the Pure trait from the clock, clock64 and globaltimer Ops. The existing Pure base class is renamed to NVVM_PureSpecialRegisterOp, and NVVM_SpecialRegisterOp is redefined without the Pure trait to represent Ops which return non-constant values. This prevents the CSE pass from optimizing away redundant uses of them.
Added:
mlir/test/Dialect/LLVMIR/cse-nvvm.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6895e946b8a45..45a8904375e2b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -153,14 +153,20 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
// NVVM special register op definitions
//===----------------------------------------------------------------------===//
-class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
let arguments = (ins);
let assemblyFormat = "attr-dict `:` type($res)";
}
-class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
- NVVM_SpecialRegisterOp<mnemonic,
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+ NVVM_IntrOp<mnemonic, traits, 1> {
+ let arguments = (ins);
+ let assemblyFormat = "attr-dict `:` type($res)";
+}
+
+class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+ NVVM_PureSpecialRegisterOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
@@ -189,63 +195,63 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
// Lane, Warp, SM, Grid index and range
-def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
-def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
-def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
-def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
-def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
-def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
+def NVVM_LaneIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
+def NVVM_WarpSizeOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_WarpIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
+def NVVM_WarpDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_SmIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
+def NVVM_SmDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
+def NVVM_GridIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
-def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
-def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
-def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
-def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
+def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
+def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
+def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
+def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
//===----------------------------------------------------------------------===//
// Thread index and range
-def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
-def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
-def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
-def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
-def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
-def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
+def NVVM_ThreadIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
+def NVVM_ThreadIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
+def NVVM_ThreadIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
+def NVVM_BlockDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
+def NVVM_BlockDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
+def NVVM_BlockDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
//===----------------------------------------------------------------------===//
// Block index and range
-def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
-def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
-def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
-def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
-def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
-def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
+def NVVM_BlockIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
+def NVVM_BlockIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
+def NVVM_BlockIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
+def NVVM_GridDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
+def NVVM_GridDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
+def NVVM_GridDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
//===----------------------------------------------------------------------===//
// CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
-def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
-def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
-def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
-def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
-def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
+def NVVM_ClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
+def NVVM_ClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
+def NVVM_ClusterDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
+def NVVM_ClusterDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
+def NVVM_ClusterDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
-def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+def NVVM_BlockInClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
+def NVVM_BlockInClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
+def NVVM_BlockInClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
//===----------------------------------------------------------------------===//
// CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
-def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
+def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
//===----------------------------------------------------------------------===//
// Clock registers
@@ -256,7 +262,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
//===----------------------------------------------------------------------===//
// envreg registers
foreach index = !range(0, 32) in {
- def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
+ def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir b/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir
new file mode 100644
index 0000000000000..8d24c3846f178
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @nvvm_special_regs_clock
+llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
+ %0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
+ %1 = nvvm.read.ptx.sreg.clock : i32
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
+ %2 = nvvm.read.ptx.sreg.clock : i32
+ %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
+ %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
+ llvm.return %5: !llvm.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_special_regs_clock64
+llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
+ %0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
+ %1 = nvvm.read.ptx.sreg.clock64 : i64
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
+ %2 = nvvm.read.ptx.sreg.clock64 : i64
+ %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
+ %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
+ llvm.return %5: !llvm.struct<(i64, i64)>
+}
+
+// CHECK-LABEL: @nvvm_special_regs_globaltimer
+llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
+ %0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
+ %1 = nvvm.read.ptx.sreg.globaltimer : i64
+ // CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
+ %2 = nvvm.read.ptx.sreg.globaltimer : i64
+ %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
+ %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
+ llvm.return %5: !llvm.struct<(i64, i64)>
+}
More information about the Mlir-commits
mailing list