[llvm] [NVPTX] Mark callseq insts as reading and writing memory (PR #151376)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 30 11:46:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>

To prevent the st.param and ld.param instructions, which store parameters and load return values, from being sunk or hoisted out of a call sequence, mark the callseq start and end nodes as reading and writing memory.

Fixes #<!-- -->151329

---
Full diff: https://github.com/llvm/llvm-project/pull/151376.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+38-32) 
- (added) llvm/test/CodeGen/NVPTX/ld-param-sink.ll (+47) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 86d6f7c3fc3a3..ee4bccaba9419 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1830,6 +1830,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size),
 def : Pat<(declare_scalar_param externalsym:$a, imm:$size),
           (DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>;
 
+// Call prototype wrapper. This is a dummy instruction that just prints its
+// operand, which is a string defining the prototype.
+def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def CallPrototype :
+  SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; }
+def CALL_PROTOTYPE :
+  NVPTXInst<(outs), (ins ProtoIdent:$ident),
+            "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
+
+
 foreach t = [I32RT, I64RT] in {
   defvar inst_name = "MOV" # t.Size # "_PARAM";
   def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>;
@@ -1849,6 +1861,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16",  B16>;
 defm ProxyRegB32 : ProxyRegInst<"b32",  B32>;
 defm ProxyRegB64 : ProxyRegInst<"b64",  B64>;
 
+
+// Callseq start and end
+
+// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because
+// they define the scope in which the declared params may be used. Therefore
+// we add these flags to ensure ld.param and st.param are not sunk or hoisted
+// out of that scope.
+
+def callseq_start : SDNode<"ISD::CALLSEQ_START",
+                           SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+                           [SDNPHasChain, SDNPOutGlue,
+                            SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+def callseq_end   : SDNode<"ISD::CALLSEQ_END",
+                           SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+                           [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+                            SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+
+def Callseq_Start :
+  NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+            "\\{ // callseq $amt1, $amt2",
+            [(callseq_start timm:$amt1, timm:$amt2)]>;
+def Callseq_End :
+  NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+            "\\} // callseq $amt1",
+            [(callseq_end timm:$amt1, timm:$amt2)]>;
+
 //
 // Load / Store Handling
 //
@@ -2392,26 +2430,6 @@ def : Pat<(brcond i32:$a, bb:$target),
 def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
           (CBranchOther $a, bb:$target)>;
 
-// Call
-def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
-                                            SDTCisVT<1, i32>]>;
-def SDT_NVPTXCallSeqEnd   : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-
-def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart,
-                           [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
-def callseq_end   : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd,
-                           [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
-                            SDNPSideEffect]>;
-
-def Callseq_Start :
-  NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
-            "\\{ // callseq $amt1, $amt2",
-            [(callseq_start timm:$amt1, timm:$amt2)]>;
-def Callseq_End :
-  NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
-            "\\} // callseq $amt1",
-            [(callseq_end timm:$amt1, timm:$amt2)]>;
-
 // trap instruction
 def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
 // Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
@@ -2420,18 +2438,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[
 // brkpt instruction
 def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>;
 
-// Call prototype wrapper
-def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def CallPrototype :
-  SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
-         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def ProtoIdent : Operand<i32> {
-  let PrintMethod = "printProtoIdent";
-}
-def CALL_PROTOTYPE :
-  NVPTXInst<(outs), (ins ProtoIdent:$ident),
-            "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
-
 def SDTDynAllocaOp :
   SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>;
 
diff --git a/llvm/test/CodeGen/NVPTX/ld-param-sink.ll b/llvm/test/CodeGen/NVPTX/ld-param-sink.ll
new file mode 100644
index 0000000000000..03523a3be50c2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ld-param-sink.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare ptr @bar(i64)
+declare i64 @baz()
+
+define ptr @foo(i1 %cond) {
+; CHECK-LABEL: foo(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b8 %rs1, [foo_param_0];
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.ne.b16 %p1, %rs2, 0;
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .b64 retval0;
+; CHECK-NEXT:    call.uni (retval0), baz, ();
+; CHECK-NEXT:    ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    @%p1 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %bb
+; CHECK-NEXT:    { // callseq 1, 0
+; CHECK-NEXT:    .param .b64 param0;
+; CHECK-NEXT:    .param .b64 retval0;
+; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    call.uni (retval0), bar, (param0);
+; CHECK-NEXT:    } // callseq 1
+; CHECK-NEXT:  $L__BB0_2: // %common.ret
+; CHECK-NEXT:    st.param.b64 [func_retval0], 0;
+; CHECK-NEXT:    ret;
+entry:
+  %call = call i64 @baz()
+  br i1 %cond, label %common.ret, label %bb
+
+bb:
+  %tmp = call ptr @bar(i64 %call)
+  br label %common.ret
+
+common.ret:
+  ret ptr null
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/151376


More information about the llvm-commits mailing list