[llvm] [llvm][NVPTX] Don't reorder MIs that construct a PTX function call (PR #116522)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 16 18:56:40 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Youngsuk Kim (JOE1994)

<details>
<summary>Changes</summary>

With `-enable-misched`, the MachineScheduler can reorder MachineInstrs (MIs) that must stay together, in their original order, to produce legal PTX code for a function call.

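Before this revision, generating PTX code for the attached test produces the following invalid output, in which unrelated instructions are scheduled into the middle of the call sequence (between `call.uni (retval0),` and `_FOO,`):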

```
    { // callseq 0, 0
    .param .b64 param0;
    st.param.f64  [param0], %fd1;
    .param .b64 retval0;
    call.uni (retval0),
    cvt.u32.u64   %r20, %rd18;
    mad.lo.s32  %r21, %r7, %r20, 1;
    cvt.rn.f64.s32  %fd4, %r21;
    _FOO,
    (
    param0
    );
    ld.param.f64  %fd2, [retval0];
    add.s32   %r22, %r18, 1;
    cvt.rn.f64.s32  %fd5, %r22;
    } // callseq 0
```

---
Full diff: https://github.com/llvm/llvm-project/pull/116522.diff


3 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+20) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.h (+3) 
- (added) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+108) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index bec40874c89488..9c1ed0d5f5abd9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -202,3 +202,23 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
   BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
   return 2;
 }
+
+bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
+                                          const MachineBasicBlock *MBB,
+                                          const MachineFunction &MF) const {
+  // Prevent the scheduler from reordering & splitting up MachineInstrs
+  // which must stick together (in initially set order) to
+  // comprise a valid PTX function call sequence.
+  switch (MI.getOpcode()) {
+    case NVPTX::CallUniPrintCallRetInst1:
+    case NVPTX::CallArgBeginInst:
+    case NVPTX::CallArgI32imm:
+    case NVPTX::CallArgParam:
+    case NVPTX::LastCallArgI32imm:
+    case NVPTX::LastCallArgParam:
+    case NVPTX::CallArgEndInst1:
+      return true;
+  }
+
+  return TargetInstrInfo::isSchedulingBoundary(MI, MBB, MF);
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
index f674a00bc351bf..a1d9f017120188 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
@@ -67,6 +67,9 @@ class NVPTXInstrInfo : public NVPTXGenInstrInfo {
                         MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
                         const DebugLoc &DL,
                         int *BytesAdded = nullptr) const override;
+  bool isSchedulingBoundary(const MachineInstr &MI,
+                            const MachineBasicBlock *MBB,
+                            const MachineFunction &MF) const override;
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/NVPTX/misched_func_call.ll b/llvm/test/CodeGen/NVPTX/misched_func_call.ll
new file mode 100644
index 00000000000000..c54674a9c791f2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/misched_func_call.ll
@@ -0,0 +1,108 @@
+; RUN: llc -O3 -march=nvptx64 -enable-misched %s -o - | FileCheck %s
+
+; ModuleID = 'The Accel Module'
+source_filename = "The Accel Module"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-nvidia-cuda"
+
+; Function Attrs: noinline
+define ptx_kernel void @"my_kernel"(i32 %"arg_0", i64 %"arg_1", i64 %"arg_2", i64 %"arg_3") {
+"Entry_BB":
+%r = tail call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+%r6 = tail call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+%r7 = mul i32 %r, %r6
+%r9 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+%r10 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+%r11 = mul i32 %r, %r9
+%r12 = add i32 %r10, %r11
+%"arg_3.tr" = trunc i64 %"arg_3" to i32
+%r16 = shl i32 %"arg_3.tr", 1
+%r19.not = icmp slt i32 %r12, %r16
+br i1 %r19.not, label %"BB1490", label %"EXIT_BB"
+
+"BB1490": ; preds = %"Entry_BB"
+%r23 = sext i32 %"arg_0" to i64
+%r24 = shl nsw i64 %r23, 3
+br label %"BB1692"
+
+"BB1692": ; preds = %"BB18", %"BB1490"
+%"$$i_l40_0_t23.0" = phi i32 [ %r12, %"BB1490" ], [ %r80, %"BB18" ]
+%r28 = sext i32 %"$$i_l40_0_t23.0" to i64
+%0 = or i64 %r28, %"arg_3"
+%1 = and i64 %0, -4294967296
+%2 = icmp eq i64 %1, 0
+br i1 %2, label %3, label %8
+
+3:                                                ; preds = %"BB1692"
+%4 = trunc i64 %"arg_3" to i32
+%5 = trunc i64 %r28 to i32
+%6 = udiv i32 %5, %4
+%7 = zext i32 %6 to i64
+br label %"BB18"
+
+8:                                                ; preds = %"BB1692"
+%9 = sdiv i64 %r28, %"arg_3"
+br label %"BB18"
+
+"BB18": ; preds = %8, %3
+%10 = phi i64 [ %7, %3 ], [ %9, %8 ]
+%r31 = trunc i64 %10 to i32
+%.neg = mul i64 %10, -4294967296
+%r35 = ashr exact i64 %.neg, 32
+%r38 = mul i64 %"arg_3", %r35
+%r39 = add i64 %r28, %r38
+%r42 = mul i64 %r24, %r39
+%r44 = mul i32 %r31, 10
+%r47 = inttoptr i64 %"arg_1" to ptr addrspace(1)
+%gep2 = getelementptr i8, ptr addrspace(1) %r47, i64 %r42
+%11 = sext i32 %r44 to i64
+%r53 = getelementptr double, ptr addrspace(1) %gep2, i64 %11
+%r54 = load double, ptr addrspace(1) %r53, align 8
+; CHECK:      call.uni (retval0),
+; CHECK-NEXT: _FOO,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+%r55 = tail call double @_FOO(double %r54)
+%12 = trunc i64 %r39 to i32
+%r59 = mul i32 %"arg_0", %12
+%r60 = add i32 %r59, 1
+%r61 = sitofp i32 %r60 to double
+%r65 = add i32 %r31, 1
+%r66 = sitofp i32 %r65 to double
+%r67 = tail call double @llvm.fma.f64(double %r55, double %r55, double %r66)
+%r68 = fadd double %r67, %r61
+%r71 = inttoptr i64 %"arg_2" to ptr addrspace(1)
+%gep88 = getelementptr i8, ptr addrspace(1) %r71, i64 %r42
+%r77 = getelementptr double, ptr addrspace(1) %gep88, i64 %11
+store double %r68, ptr addrspace(1) %r77, align 8
+%r80 = add i32 %r7, %"$$i_l40_0_t23.0"
+%r85 = icmp slt i32 %r80, %r16
+br i1 %r85, label %"BB1692", label %"EXIT_BB"
+
+"EXIT_BB": ; preds = %"BB18", %"Entry_BB"
+ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() #1
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
+declare double @_FOO(double)
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare double @llvm.fma.f64(double, double, double) #1
+
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+
+!llvm.module.flags = !{ !2}
+
+!2 = !{i32 4, !"nvvm-reflect-ftz", i32 0}

``````````

</details>


https://github.com/llvm/llvm-project/pull/116522


More information about the llvm-commits mailing list