[llvm] [NVPTX] Mark callseq insts as reading and writing memory (PR #151376)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 30 11:46:10 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
In order to prevent the st.param and ld.param instructions which store parameters and load return values from being sunk or hoisted out of a call sequence, mark the callseq start and end nodes as reading and writing memory.
Fixes #<!-- -->151329
---
Full diff: https://github.com/llvm/llvm-project/pull/151376.diff
2 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+38-32)
- (added) llvm/test/CodeGen/NVPTX/ld-param-sink.ll (+47)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 86d6f7c3fc3a3..ee4bccaba9419 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1830,6 +1830,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size),
def : Pat<(declare_scalar_param externalsym:$a, imm:$size),
(DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>;
+// Call prototype wrapper, this is a dummy instruction that just prints its
+// operand which is a string defining the prototype.
+def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def CallPrototype :
+ SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; }
+def CALL_PROTOTYPE :
+ NVPTXInst<(outs), (ins ProtoIdent:$ident),
+ "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
+
+
foreach t = [I32RT, I64RT] in {
defvar inst_name = "MOV" # t.Size # "_PARAM";
def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>;
@@ -1849,6 +1861,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>;
defm ProxyRegB32 : ProxyRegInst<"b32", B32>;
defm ProxyRegB64 : ProxyRegInst<"b64", B64>;
+
+// Callseq start and end
+
+// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because
+// they define the scope in which the declared params may be used. Therefore
+// we add these flags to ensure ld.param and st.param are not sunk or hoisted
+// out of that scope.
+
+def callseq_start : SDNode<"ISD::CALLSEQ_START",
+ SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+def callseq_end : SDNode<"ISD::CALLSEQ_END",
+ SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+
+def Callseq_Start :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\{ // callseq $amt1, $amt2",
+ [(callseq_start timm:$amt1, timm:$amt2)]>;
+def Callseq_End :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\} // callseq $amt1",
+ [(callseq_end timm:$amt1, timm:$amt2)]>;
+
//
// Load / Store Handling
//
@@ -2392,26 +2430,6 @@ def : Pat<(brcond i32:$a, bb:$target),
def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
(CBranchOther $a, bb:$target)>;
-// Call
-def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
- SDTCisVT<1, i32>]>;
-def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-
-def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart,
- [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
-def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd,
- [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
- SDNPSideEffect]>;
-
-def Callseq_Start :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\{ // callseq $amt1, $amt2",
- [(callseq_start timm:$amt1, timm:$amt2)]>;
-def Callseq_End :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\} // callseq $amt1",
- [(callseq_end timm:$amt1, timm:$amt2)]>;
-
// trap instruction
def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
@@ -2420,18 +2438,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[
// brkpt instruction
def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>;
-// Call prototype wrapper
-def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def CallPrototype :
- SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def ProtoIdent : Operand<i32> {
- let PrintMethod = "printProtoIdent";
-}
-def CALL_PROTOTYPE :
- NVPTXInst<(outs), (ins ProtoIdent:$ident),
- "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
-
def SDTDynAllocaOp :
SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>;
diff --git a/llvm/test/CodeGen/NVPTX/ld-param-sink.ll b/llvm/test/CodeGen/NVPTX/ld-param-sink.ll
new file mode 100644
index 0000000000000..03523a3be50c2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ld-param-sink.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare ptr @bar(i64)
+declare i64 @baz()
+
+define ptr @foo(i1 %cond) {
+; CHECK-LABEL: foo(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b8 %rs1, [foo_param_0];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.ne.b16 %p1, %rs2, 0;
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .b64 retval0;
+; CHECK-NEXT: call.uni (retval0), baz, ();
+; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: @%p1 bra $L__BB0_2;
+; CHECK-NEXT: // %bb.1: // %bb
+; CHECK-NEXT: { // callseq 1, 0
+; CHECK-NEXT: .param .b64 param0;
+; CHECK-NEXT: .param .b64 retval0;
+; CHECK-NEXT: st.param.b64 [param0], %rd2;
+; CHECK-NEXT: call.uni (retval0), bar, (param0);
+; CHECK-NEXT: } // callseq 1
+; CHECK-NEXT: $L__BB0_2: // %common.ret
+; CHECK-NEXT: st.param.b64 [func_retval0], 0;
+; CHECK-NEXT: ret;
+entry:
+ %call = call i64 @baz()
+ br i1 %cond, label %common.ret, label %bb
+
+bb:
+ %tmp = call ptr @bar(i64 %call)
+ br label %common.ret
+
+common.ret:
+ ret ptr null
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/151376
More information about the llvm-commits
mailing list