[llvm] [NVPTX] support packed f32 instructions for sm_100+ (PR #126337)

via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 8 11:59:05 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

<details>
<summary>Changes</summary>

This adds support for lowering `fadd`, `fsub`, `fmul`, and `fma` to sm_100+ packed-f32 instructions[^1] (e.g. `add.rn.f32x2 Int64Reg, Int64Reg`). Rather than legalizing `v2f32`, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present. We also introduce some DAGCombiner rules to simplify bitwise packing/unpacking to use `mov`, and to reduce redundant `mov`s.

In this PR I didn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand, I can add that to this PR. Otherwise we can leave that for later.

[^1]: Introduced in PTX 8.6: https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-6

---

Patch is 127.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126337.diff


10 Files Affected:

- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+7) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+19) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+275-9) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+4) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+14-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+30) 
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+3-1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3) 
- (added) llvm/test/CodeGen/NVPTX/f32x2-instructions.ll (+2665) 


``````````diff
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174e..394428594b9870 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [     // ptradd
 def SDTIntBinOp : SDTypeProfile<1, 2, [     // add, and, or, xor, udiv, etc.
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
 ]>;
+def SDTIntTernaryOp : SDTypeProfile<1, 3, [  // fma32x2
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>
+]>;
 def SDTIntShiftOp : SDTypeProfile<1, 2, [   // shl, sra, srl
   SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
 ]>;
@@ -818,6 +821,10 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
 def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
                               []>;
 
+def build_pair : SDNode<"ISD::BUILD_PAIR", SDTypeProfile<1, 2,
+                        [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, []>;
+
+
 // vector_extract/vector_insert are deprecated. extractelt/insertelt
 // are preferred.
 def vector_extract : SDNode<"ISD::EXTRACT_VECTOR_ELT",
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec654e0f3f200f..3a39f6dab0c85f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -190,6 +190,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
       SelectI128toV2I64(N);
       return;
     }
+    if (N->getOperand(1).getValueType() == MVT::i64 &&
+        N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32) {
+      // {f32,f32} = mov i64
+      SelectI64ToV2F32(N);
+      return;
+    }
     break;
   }
   case ISD::FADD:
@@ -2765,6 +2771,19 @@ void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
   ReplaceNode(N, Mov);
 }
 
+void NVPTXDAGToDAGISel::SelectI64ToV2F32(SDNode *N) {
+  SDValue Ch = N->getOperand(0);
+  SDValue Src = N->getOperand(1);
+  assert(N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32 &&
+         "expected {f32,f32} = CopyFromReg i64");
+  SDLoc DL(N);
+
+  SDNode *Mov = CurDAG->getMachineNode(NVPTX::I64toV2F32, DL,
+                                       {MVT::f32, MVT::f32, Ch.getValueType()},
+                                       {Src, Ch});
+  ReplaceNode(N, Mov);
+}
+
 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
 /// conversion from \p SrcTy to \p DestTy.
 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8dc6bc86c68281..703a80f74e90c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -91,6 +91,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
   void SelectV2I64toI128(SDNode *N);
   void SelectI128toV2I64(SDNode *N);
+  void SelectI64ToV2F32(SDNode *N);
   void SelectCpAsyncBulkG2S(SDNode *N);
   void SelectCpAsyncBulkS2G(SDNode *N);
   void SelectCpAsyncBulkPrefetchL2(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 58ad92a8934a66..1e417f23fdb099 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,6 +866,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
   // (would be) Library functions.
 
+  if (STI.hasF32x2Instructions()) {
+    // Handle custom lowering for: v2f32 = OP v2f32, v2f32
+    for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
+      setOperationAction(Op, MVT::v2f32, Custom);
+    // Handle custom lowering for: f32 = extract_vector_elt v2f32
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+    // Combine:
+    // i64 = or (i64 = zero_extend X, i64 = shl (i64 = any_extend Y, 32))
+    // -> i64 = build_pair (X, Y)
+    setTargetDAGCombine(ISD::OR);
+    // i32 = truncate (i64 = srl (i64 = build_pair (X, Y), 32))
+    // -> i32 Y
+    setTargetDAGCombine(ISD::TRUNCATE);
+    // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+    // -> i64 X
+    setTargetDAGCombine(ISD::BUILD_PAIR);
+  }
+
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1084,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::STACKSAVE)
     MAKE_CASE(NVPTXISD::SETP_F16X2)
     MAKE_CASE(NVPTXISD::SETP_BF16X2)
+    MAKE_CASE(NVPTXISD::FADD_F32X2)
+    MAKE_CASE(NVPTXISD::FSUB_F32X2)
+    MAKE_CASE(NVPTXISD::FMUL_F32X2)
+    MAKE_CASE(NVPTXISD::FMA_F32X2)
     MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2207,6 +2229,30 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
     return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
   }
 
+  if (VectorVT == MVT::v2f32) {
+    auto GetOperand = [&DAG, &DL](SDValue Op, SDValue Index) {
+      if (const auto *ConstIdx = dyn_cast<ConstantSDNode>(Index))
+        return Op.getOperand(ConstIdx->getZExtValue());
+      SDValue E0 = Op.getOperand(0);
+      SDValue E1 = Op.getOperand(1);
+      return DAG.getSelectCC(DL, Index, DAG.getIntPtrConstant(0, DL), E0, E1,
+                             ISD::CondCode::SETEQ);
+    };
+    if (SDValue Pair = Vector.getOperand(0);
+        Vector.getOpcode() == ISD::BITCAST &&
+        Pair.getOpcode() == ISD::BUILD_PAIR) {
+      // peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
+      // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
+      return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(),
+                         GetOperand(Pair, Index));
+    }
+    if (Vector.getOpcode() == ISD::BUILD_VECTOR)
+      return GetOperand(Vector, Index);
+
+    // Otherwise, let SelectionDAG expand the operand.
+    return SDValue();
+  }
+
   // Constant index will be matched by tablegen.
   if (isa<ConstantSDNode>(Index.getNode()))
     return Op;
@@ -4573,26 +4619,109 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
+// If {Lo, Hi} = <packed f32x2 val>, returns that value
+static SDValue peekThroughF32x2Copy(const SDValue &Lo, const SDValue &Hi) {
+  if (Lo.getValueType() != MVT::f32 || Lo.getOpcode() != ISD::CopyFromReg ||
+      Lo.getNode() != Hi.getNode() || Lo == Hi)
+    return SDValue();
+
+  SDNode *CopyF = Lo.getNode();
+  SDNode *CopyT = CopyF->getOperand(0).getNode();
+  if (CopyT->getOpcode() != ISD::CopyToReg)
+    return SDValue();
+
+  // check the two registers are the same
+  if (cast<RegisterSDNode>(CopyF->getOperand(1))->getReg() !=
+      cast<RegisterSDNode>(CopyT->getOperand(1))->getReg())
+    return SDValue();
+
+  SDValue OrigV = CopyT->getOperand(2);
+  if (OrigV.getValueType() != MVT::i64)
+    return SDValue();
+  return OrigV;
+}
+
+static SDValue
+PerformPackedF32StoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                             CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  // rewrite stores of packed f32 values
+  auto *MemN = cast<MemSDNode>(N);
+  if (MemN->getMemoryVT() == MVT::f32) {
+    std::optional<NVPTXISD::NodeType> NewOpcode;
+    switch (MemN->getOpcode()) {
+    case NVPTXISD::StoreRetvalV2:
+      NewOpcode = NVPTXISD::StoreRetval;
+      break;
+    case NVPTXISD::StoreRetvalV4:
+      NewOpcode = NVPTXISD::StoreRetvalV2;
+      break;
+    case NVPTXISD::StoreParamV2:
+      NewOpcode = NVPTXISD::StoreParam;
+      break;
+    case NVPTXISD::StoreParamV4:
+      NewOpcode = NVPTXISD::StoreParamV2;
+      break;
+    }
+
+    if (NewOpcode) {
+      SmallVector<SDValue> NewOps = {N->getOperand(0), N->getOperand(1)};
+      unsigned NumPacked = 0;
+
+      // gather all packed operands
+      for (unsigned I = 2, E = MemN->getNumOperands(); I < E; I += 2) {
+        if (SDValue Packed = peekThroughF32x2Copy(MemN->getOperand(I),
+                                                  MemN->getOperand(I + 1))) {
+          NewOps.push_back(Packed);
+          ++NumPacked;
+        } else {
+          NumPacked = 0;
+          break;
+        }
+      }
+
+      if (NumPacked) {
+        return DCI.DAG.getMemIntrinsicNode(
+            *NewOpcode, SDLoc(N), N->getVTList(), NewOps, MVT::i64,
+            MemN->getPointerInfo(), MemN->getAlign(),
+            MachineMemOperand::MOStore);
+      }
+    }
+  }
+  return SDValue();
+}
+
 static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+                                         std::size_t Back,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         CodeGenOptLevel OptLevel) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);
 
+  if (SDValue V = PerformPackedF32StoreCombine(N, DCI, OptLevel))
+    return V;
+
   return SDValue();
 }
 
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        CodeGenOptLevel OptLevel) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   //   {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, 3, 1, DCI, OptLevel);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         CodeGenOptLevel OptLevel) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, 2, 0, DCI, OptLevel);
 }
 
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5055,10 +5184,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
       IsPTXVectorType(VectorVT.getSimpleVT()))
     return SDValue(); // Native vector loads already combine nicely w/
                       // extract_vector_elt.
-  // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
+  // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
   // handle them OK.
   if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
-      VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
+      VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32)
     return SDValue();
 
   // Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5188,6 +5317,78 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue PerformORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                                CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  // i64 = or (i64 = zero_extend A, i64 = shl (i64 = any_extend B, 32))
+  // -> i64 = build_pair (A, B)
+  if (N->getValueType(0) == MVT::i64 && Op0.getOpcode() == ISD::ZERO_EXTEND &&
+      Op1.getOpcode() == ISD::SHL) {
+    SDValue SHLOp0 = Op1.getOperand(0);
+    SDValue SHLOp1 = Op1.getOperand(1);
+    if (const auto *Const = dyn_cast<ConstantSDNode>(SHLOp1);
+        Const && Const->getZExtValue() == 32 &&
+        SHLOp0.getOpcode() == ISD::ANY_EXTEND) {
+      SDLoc DL(N);
+      return DCI.DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+                             {Op0.getOperand(0), SHLOp0.getOperand(0)});
+    }
+  }
+  return SDValue();
+}
+
+static SDValue PerformTRUNCATECombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  SDValue Op = N->getOperand(0);
+  if (Op.getOpcode() == ISD::SRL) {
+    SDValue SrlOp = Op.getOperand(0);
+    SDValue SrlSh = Op.getOperand(1);
+    // i32 = truncate (i64 = srl (i64 = build_pair (A, B), 32))
+    // -> i32 B
+    if (const auto *Const = dyn_cast<ConstantSDNode>(SrlSh);
+        Const && Const->getZExtValue() == 32) {
+      if (SrlOp.getOpcode() == ISD::BUILD_PAIR)
+        return SrlOp.getOperand(1);
+    }
+  }
+
+  return SDValue();
+}
+
+static SDValue PerformBUILD_PAIRCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  EVT ToVT = N->getValueType(0);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+  // -> i64 X
+  if (ToVT == MVT::i64 && Op0.getOpcode() == ISD::CopyFromReg &&
+      Op1.getNode() == Op0.getNode() && Op0 != Op1) {
+    SDValue CFRChain = Op0.getOperand(0);
+    Register Reg = cast<RegisterSDNode>(Op0.getOperand(1))->getReg();
+    if (CFRChain.getOpcode() == ISD::CopyToReg &&
+        cast<RegisterSDNode>(CFRChain.getOperand(1))->getReg() == Reg) {
+      SDValue Value = CFRChain.getOperand(2);
+      return Value;
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5211,17 +5412,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
-      return PerformStoreRetvalCombine(N);
+      return PerformStoreRetvalCombine(N, DCI, OptLevel);
     case NVPTXISD::StoreParam:
     case NVPTXISD::StoreParamV2:
     case NVPTXISD::StoreParamV4:
-      return PerformStoreParamCombine(N);
+      return PerformStoreParamCombine(N, DCI, OptLevel);
     case ISD::EXTRACT_VECTOR_ELT:
       return PerformEXTRACTCombine(N, DCI);
     case ISD::VSELECT:
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
+    case ISD::OR:
+      return PerformORCombine(N, DCI, OptLevel);
+    case ISD::TRUNCATE:
+      return PerformTRUNCATECombine(N, DCI, OptLevel);
+    case ISD::BUILD_PAIR:
+      return PerformBUILD_PAIRCombine(N, DCI, OptLevel);
   }
   return SDValue();
 }
@@ -5478,6 +5685,59 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
   Results.push_back(NewValue.getValue(3));
 }
 
+static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
+                           SmallVectorImpl<SDValue> &Results) {
+  SDLoc DL(N);
+  EVT OldResultTy = N->getValueType(0); // <2 x float>
+  assert(OldResultTy == MVT::v2f32 && "Unexpected result type for F32x2 op!");
+
+  SmallVector<SDValue> NewOps;
+
+  // whether we use FTZ (TODO)
+
+  // replace with NVPTX F32x2 op:
+  unsigned Opcode;
+  switch (N->getOpcode()) {
+  case ISD::FADD:
+    Opcode = NVPTXISD::FADD_F32X2;
+    break;
+  case ISD::FSUB:
+    Opcode = NVPTXISD::FSUB_F32X2;
+    break;
+  case ISD::FMUL:
+    Opcode = NVPTXISD::FMUL_F32X2;
+    break;
+  case ISD::FMA:
+    Opcode = NVPTXISD::FMA_F32X2;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+
+  // bitcast operands: <2 x float> -> i64
+  for (const SDValue &Op : N->ops())
+    NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
+
+  SDValue Chain = DAG.getEntryNode();
+
+  // break packed result into two f32 registers for later instructions that may
+  // access element #0 or #1
+  SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
+  MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+  Register DestReg = RegInfo.createVirtualRegister(
+      DAG.getTargetLoweringInfo().getRegClassFor(MVT::i64));
+  SDValue RegCopy = DAG.getCopyToReg(Chain, DL, DestReg, NewValue);
+  SDValue Explode = DAG.getNode(ISD::CopyFromReg, DL,
+                                {MVT::f32, MVT::f32, Chain.getValueType()},
+                                {RegCopy, DAG.getRegister(DestReg, MVT::i64)});
+  // cast i64 result of new op back to <2 x float>
+  Results.push_back(DAG.getBitcast(
+      OldResultTy,
+      DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+                  {DAG.getBitcast(MVT::i32, Explode.getValue(0)),
+                   DAG.getBitcast(MVT::i32, Explode.getValue(1))})));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -5495,6 +5755,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::CopyFromReg:
     ReplaceCopyFromReg_128(N, DAG, Results);
     return;
+  case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+  case ISD::FMA:
+    ReplaceF32x2Op(N, DAG, Results);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..8fd4ded42a238a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
   FSHR_CLAMP,
   MUL_WIDE_SIGNED,
   MUL_WIDE_UNSIGNED,
+  FADD_F32X2,
+  FMUL_F32X2,
+  FSUB_F32X2,
+  FMA_F32X2,
   SETP_F16X2,
   SETP_BF16X2,
   BFE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7d9697e40e6aba..b0eb9bbbb2456a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
 def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
 def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
 def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
+def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
 
 def True : Predicate<"true">;
 def False : Predicate<"false">;
@@ -2638,13 +2639,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
   NVPTXInst<(outs), (ins regclass:$a), "$a",
             [(LastCallArg (i32 0), vt:$a)]>;
 
-def CallArgI64     : CallArgInst<Int64Regs>;
+def CallArgI64     : CallArgInstVT<Int64Regs, i64>;
 def CallArgI32     : CallArgInstVT<Int32Regs, i32>;
 def CallArgI16     : CallArgInstVT<Int16Regs, i16>;
 def CallArgF64     : CallArgInst<Float64Regs>;
 def CallArgF32     : CallArgInst<Float32Regs>;
 
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
+def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
 def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
 def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
 def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3371,6 +3372,9 @@ let hasSideEffects = false in {
   def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
                              (ins Float32Regs:$s1, Float32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
+  def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
+                             (ins Float32Regs:$s1, Float32Regs:$s2),
+                             "mov.b64 \t$d, {{$s1, $s2}};", []>;
 
   // unpack a larger int register to a set of smaller int registers
   def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3383,6 +3387,9 @@ let hasSideEffects = false in {
   def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
                              (ins Int64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/126337


More information about the llvm-commits mailing list