[llvm] r262373 - [NVPTX] Use different, convergent MIs for convergent calls.

Justin Lebar via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 1 11:24:03 PST 2016


Author: jlebar
Date: Tue Mar  1 13:24:03 2016
New Revision: 262373

URL: http://llvm.org/viewvc/llvm-project?rev=262373&view=rev
Log:
[NVPTX] Use different, convergent MIs for convergent calls.

Summary:
Calls sometimes need to be convergent.  This is already handled at the
LLVM IR level, but it also needs to be handled at the MachineInstr (MI)
level.

Ideally we'd propagate convergence from LLVM IR instructions, down through
the selection DAG, and into MIs.  But this is Hard, and would affect
optimizations on the SelectionDAG nodes (SDNodes) -- right now only SDNodes
with two operands have any flags at all.

Instead, here's a much simpler hack: Add new NVPTX opcodes for convergent
calls, and generate these when lowering convergent LLVM IR calls.

Reviewers: jholewinski

Subscribers: jholewinski, chandlerc, joker.eph, jhen, tra, llvm-commits

Differential Revision: http://reviews.llvm.org/D17423

Added:
    llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll
Modified:
    llvm/trunk/include/llvm/Target/TargetLowering.h
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td

Modified: llvm/trunk/include/llvm/Target/TargetLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Target/TargetLowering.h?rev=262373&r1=262372&r2=262373&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Target/TargetLowering.h (original)
+++ llvm/trunk/include/llvm/Target/TargetLowering.h Tue Mar  1 13:24:03 2016
@@ -2348,6 +2348,7 @@ public:
     bool IsInReg           : 1;
     bool DoesNotReturn     : 1;
     bool IsReturnValueUsed : 1;
+    bool IsConvergent      : 1;
 
     // IsTailCall should be modified by implementations of
     // TargetLowering::LowerCall that perform tail call conversions.
@@ -2366,10 +2367,11 @@ public:
     SmallVector<ISD::InputArg, 32> Ins;
 
     CallLoweringInfo(SelectionDAG &DAG)
-      : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
-        IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
-        IsTailCall(false), NumFixedArgs(-1), CallConv(CallingConv::C),
-        DAG(DAG), CS(nullptr), IsPatchPoint(false) {}
+        : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
+          IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
+          IsConvergent(false), IsTailCall(false), NumFixedArgs(-1),
+          CallConv(CallingConv::C), DAG(DAG), CS(nullptr), IsPatchPoint(false) {
+    }
 
     CallLoweringInfo &setDebugLoc(SDLoc dl) {
       DL = dl;
@@ -2441,6 +2443,11 @@ public:
       return *this;
     }
 
+    CallLoweringInfo &setConvergent(bool Value = true) {
+      IsConvergent = Value;
+      return *this;
+    }
+
     CallLoweringInfo &setSExtResult(bool Value = true) {
       RetSExt = Value;
       return *this;

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp?rev=262373&r1=262372&r2=262373&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Tue Mar  1 13:24:03 2016
@@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(Im
     isTailCall = false;
 
   TargetLowering::CallLoweringInfo CLI(DAG);
-  CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot())
-    .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
-    .setTailCall(isTailCall);
+  CLI.setDebugLoc(getCurSDLoc())
+      .setChain(getRoot())
+      .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
+      .setTailCall(isTailCall)
+      .setConvergent(CS.isConvergent());
   std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
 
   if (Result.first.getNode()) {

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=262373&r1=262372&r2=262373&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Tue Mar  1 13:24:03 2016
@@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTarg
     return "NVPTXISD::DeclareRetParam";
   case NVPTXISD::PrintCall:
     return "NVPTXISD::PrintCall";
+  case NVPTXISD::PrintConvergentCall:
+    return "NVPTXISD::PrintConvergentCall";
   case NVPTXISD::PrintCallUni:
     return "NVPTXISD::PrintCallUni";
+  case NVPTXISD::PrintConvergentCallUni:
+    return "NVPTXISD::PrintConvergentCallUni";
   case NVPTXISD::LoadParam:
     return "NVPTXISD::LoadParam";
   case NVPTXISD::LoadParamV2:
@@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(T
   SDValue PrintCallOps[] = {
     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag
   };
-  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
-                      dl, PrintCallVTs, PrintCallOps);
+  // We model convergent calls as separate opcodes.
+  unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall;
+  if (CLI.IsConvergent)
+    Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
+                                              : NVPTXISD::PrintConvergentCall;
+  Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
   InFlag = Chain.getValue(1);
 
   // Ops to print out the function name

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h?rev=262373&r1=262372&r2=262373&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h Tue Mar  1 13:24:03 2016
@@ -34,7 +34,9 @@ enum NodeType : unsigned {
   DeclareRet,
   DeclareScalarRet,
   PrintCall,
+  PrintConvergentCall,
   PrintCallUni,
+  PrintConvergentCallUni,
   CallArgBegin,
   CallArg,
   LastCallArg,

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=262373&r1=262372&r2=262373&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Tue Mar  1 13:24:03 2016
@@ -1701,9 +1701,15 @@ def LoadParamV4 :
 def PrintCall :
   SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCall :
+  SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def PrintCallUni :
   SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCallUni :
+  SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def StoreParam :
   SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
@@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst<NVPTXRegClass re
                 []>;
 
 let isCall=1 in {
- def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
-   "call ", [(PrintCall (i32 0))]>;
- def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
-   "call (retval0), ", [(PrintCall (i32 1))]>;
- def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1), ", [(PrintCall (i32 2))]>;
- def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>;
- def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>;
- def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2, retval3, retval4), ",
-   [(PrintCall (i32 5))]>;
- def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2, retval3, retval4, retval5), ",
-   [(PrintCall (i32 6))]>;
- def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
-   [(PrintCall (i32 7))]>;
- def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
-   "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
-         "retval7), ",
-   [(PrintCall (i32 8))]>;
-
- def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins),
-   "call.uni ", [(PrintCallUni (i32 0))]>;
- def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0), ", [(PrintCallUni (i32 1))]>;
- def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>;
- def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>;
- def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>;
- def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2, retval3, retval4), ",
-   [(PrintCallUni (i32 5))]>;
- def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ",
-   [(PrintCallUni (i32 6))]>;
- def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
-   [(PrintCallUni (i32 7))]>;
- def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins),
-   "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
-             "retval7), ",
-   [(PrintCallUni (i32 8))]>;
+  multiclass CALL<string OpcStr, SDNode OpNode> {
+     def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+     def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
+     def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
+     def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
+     def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
+       [(OpNode (i32 4))]>;
+     def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
+       [(OpNode (i32 5))]>;
+     def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                            "retval5), "),
+       [(OpNode (i32 6))]>;
+     def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                            "retval5, retval6), "),
+       [(OpNode (i32 7))]>;
+     def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
+       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                            "retval5, retval6, retval7), "),
+       [(OpNode (i32 8))]>;
+  }
+}
+
+defm Call : CALL<"call", PrintCall>;
+defm CallUni : CALL<"call.uni", PrintCallUni>;
+
+// Convergent call instructions.  These are identical to regular calls, except
+// they have the isConvergent bit set.
+let isConvergent=1 in {
+  defm ConvergentCall : CALL<"call", PrintConvergentCall>;
+  defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>;
 }
 
 def LoadParamMemI64    : LoadParamMemInst<Int64Regs, ".b64">;

Added: llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll?rev=262373&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll (added)
+++ llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll Tue Mar  1 13:24:03 2016
@@ -0,0 +1,27 @@
+; RUN: llc -mtriple nvptx64-nvidia-cuda -stop-after machine-cp -o - < %s 2>&1 | FileCheck %s
+
+; Check that convergent calls are emitted using convergent MIR instructions,
+; while non-convergent calls are not.
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare void @conv() convergent
+declare void @not_conv()
+
+define void @test(void ()* %f) {
+  ; CHECK: ConvergentCallUniPrintCall
+  ; CHECK-NEXT: @conv
+  call void @conv()
+
+  ; CHECK: CallUniPrintCall
+  ; CHECK-NEXT: @not_conv
+  call void @not_conv()
+
+  ; CHECK: ConvergentCallPrintCall
+  call void %f() convergent
+
+  ; CHECK: CallPrintCall
+  call void %f()
+
+  ret void
+}




More information about the llvm-commits mailing list