[llvm] [NVPTX] support packed f32 instructions for sm_100+ (PR #126337)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Feb 8 11:59:05 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Princeton Ferro (Prince781)
This adds support for lowering `fadd`, `fsub`, `fmul`, and `fma` to the sm_100+ packed-f32 instructions[^1] (e.g. `add.rn.f32x2 Int64Reg, Int64Reg`). Rather than making `v2f32` a legal type, we handle these four operations ad hoc, so codegen is unchanged unless one of them appears. We also introduce some DAGCombiner rules that simplify bitwise packing/unpacking into `mov` instructions and remove redundant `mov`s.
This PR doesn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand I can add it here; otherwise we can leave it for a follow-up.
[^1]: Introduced in PTX 8.6: https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-6
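
To illustrate the intended effect, here is a minimal sketch (hypothetical; not taken from this PR's test file) of the kind of IR this lowering targets, with the expected PTX noted in comments:

```llvm
; Hypothetical example: a plain vector fadd over <2 x float>.
define <2 x float> @fadd_v2f32(<2 x float> %a, <2 x float> %b) {
  %r = fadd <2 x float> %a, %b
  ret <2 x float> %r
}
; Compiled with something like `llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86`,
; this should select a single packed instruction on 64-bit registers, roughly:
;   add.rn.f32x2 %rd3, %rd1, %rd2;
; instead of two scalar add.rn.f32 instructions.
```

The same applies to `fsub`, `fmul`, and `fma`.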
---
Patch is 127.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126337.diff
10 Files Affected:
- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+7)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+19)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+275-9)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+4)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+14-2)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+30)
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+3-1)
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3)
- (added) llvm/test/CodeGen/NVPTX/f32x2-instructions.ll (+2665)
``````````diff
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174e..394428594b9870 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
]>;
+def SDTIntTernaryOp : SDTypeProfile<1, 3, [ // fma32x2
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>
+]>;
def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
]>;
@@ -818,6 +821,10 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
[]>;
+def build_pair : SDNode<"ISD::BUILD_PAIR", SDTypeProfile<1, 2,
+ [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, []>;
+
+
// vector_extract/vector_insert are deprecated. extractelt/insertelt
// are preferred.
def vector_extract : SDNode<"ISD::EXTRACT_VECTOR_ELT",
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec654e0f3f200f..3a39f6dab0c85f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -190,6 +190,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
SelectI128toV2I64(N);
return;
}
+ if (N->getOperand(1).getValueType() == MVT::i64 &&
+ N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32) {
+ // {f32,f32} = mov i64
+ SelectI64ToV2F32(N);
+ return;
+ }
break;
}
case ISD::FADD:
@@ -2765,6 +2771,19 @@ void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
ReplaceNode(N, Mov);
}
+void NVPTXDAGToDAGISel::SelectI64ToV2F32(SDNode *N) {
+ SDValue Ch = N->getOperand(0);
+ SDValue Src = N->getOperand(1);
+ assert(N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32 &&
+ "expected {f32,f32} = CopyFromReg i64");
+ SDLoc DL(N);
+
+ SDNode *Mov = CurDAG->getMachineNode(NVPTX::I64toV2F32, DL,
+ {MVT::f32, MVT::f32, Ch.getValueType()},
+ {Src, Ch});
+ ReplaceNode(N, Mov);
+}
+
/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
/// conversion from \p SrcTy to \p DestTy.
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8dc6bc86c68281..703a80f74e90c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -91,6 +91,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
void SelectV2I64toI128(SDNode *N);
void SelectI128toV2I64(SDNode *N);
+ void SelectI64ToV2F32(SDNode *N);
void SelectCpAsyncBulkG2S(SDNode *N);
void SelectCpAsyncBulkS2G(SDNode *N);
void SelectCpAsyncBulkPrefetchL2(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 58ad92a8934a66..1e417f23fdb099 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,6 +866,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
// (would be) Library functions.
+ if (STI.hasF32x2Instructions()) {
+ // Handle custom lowering for: v2f32 = OP v2f32, v2f32
+ for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
+ setOperationAction(Op, MVT::v2f32, Custom);
+ // Handle custom lowering for: f32 = extract_vector_elt v2f32
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+ // Combine:
+ // i64 = or (i64 = zero_extend X, i64 = shl (i64 = any_extend Y, 32))
+ // -> i64 = build_pair (X, Y)
+ setTargetDAGCombine(ISD::OR);
+ // i32 = truncate (i64 = srl (i64 = build_pair (X, Y), 32))
+ // -> i32 Y
+ setTargetDAGCombine(ISD::TRUNCATE);
+ // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+ // -> i64 X
+ setTargetDAGCombine(ISD::BUILD_PAIR);
+ }
+
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1084,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::STACKSAVE)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
+ MAKE_CASE(NVPTXISD::FADD_F32X2)
+ MAKE_CASE(NVPTXISD::FSUB_F32X2)
+ MAKE_CASE(NVPTXISD::FMUL_F32X2)
+ MAKE_CASE(NVPTXISD::FMA_F32X2)
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2207,6 +2229,30 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
}
+ if (VectorVT == MVT::v2f32) {
+ auto GetOperand = [&DAG, &DL](SDValue Op, SDValue Index) {
+ if (const auto *ConstIdx = dyn_cast<ConstantSDNode>(Index))
+ return Op.getOperand(ConstIdx->getZExtValue());
+ SDValue E0 = Op.getOperand(0);
+ SDValue E1 = Op.getOperand(1);
+ return DAG.getSelectCC(DL, Index, DAG.getIntPtrConstant(0, DL), E0, E1,
+ ISD::CondCode::SETEQ);
+ };
+ if (SDValue Pair = Vector.getOperand(0);
+ Vector.getOpcode() == ISD::BITCAST &&
+ Pair.getOpcode() == ISD::BUILD_PAIR) {
+ // peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
+ // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
+ return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(),
+ GetOperand(Pair, Index));
+ }
+ if (Vector.getOpcode() == ISD::BUILD_VECTOR)
+ return GetOperand(Vector, Index);
+
+ // Otherwise, let SelectionDAG expand the operand.
+ return SDValue();
+ }
+
// Constant index will be matched by tablegen.
if (isa<ConstantSDNode>(Index.getNode()))
return Op;
@@ -4573,26 +4619,109 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}
+// If {Lo, Hi} = <packed f32x2 val>, returns that value
+static SDValue peekThroughF32x2Copy(const SDValue &Lo, const SDValue &Hi) {
+ if (Lo.getValueType() != MVT::f32 || Lo.getOpcode() != ISD::CopyFromReg ||
+ Lo.getNode() != Hi.getNode() || Lo == Hi)
+ return SDValue();
+
+ SDNode *CopyF = Lo.getNode();
+ SDNode *CopyT = CopyF->getOperand(0).getNode();
+ if (CopyT->getOpcode() != ISD::CopyToReg)
+ return SDValue();
+
+ // check the two registers are the same
+ if (cast<RegisterSDNode>(CopyF->getOperand(1))->getReg() !=
+ cast<RegisterSDNode>(CopyT->getOperand(1))->getReg())
+ return SDValue();
+
+ SDValue OrigV = CopyT->getOperand(2);
+ if (OrigV.getValueType() != MVT::i64)
+ return SDValue();
+ return OrigV;
+}
+
+static SDValue
+PerformPackedF32StoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ // rewrite stores of packed f32 values
+ auto *MemN = cast<MemSDNode>(N);
+ if (MemN->getMemoryVT() == MVT::f32) {
+ std::optional<NVPTXISD::NodeType> NewOpcode;
+ switch (MemN->getOpcode()) {
+ case NVPTXISD::StoreRetvalV2:
+ NewOpcode = NVPTXISD::StoreRetval;
+ break;
+ case NVPTXISD::StoreRetvalV4:
+ NewOpcode = NVPTXISD::StoreRetvalV2;
+ break;
+ case NVPTXISD::StoreParamV2:
+ NewOpcode = NVPTXISD::StoreParam;
+ break;
+ case NVPTXISD::StoreParamV4:
+ NewOpcode = NVPTXISD::StoreParamV2;
+ break;
+ }
+
+ if (NewOpcode) {
+ SmallVector<SDValue> NewOps = {N->getOperand(0), N->getOperand(1)};
+ unsigned NumPacked = 0;
+
+ // gather all packed operands
+ for (unsigned I = 2, E = MemN->getNumOperands(); I < E; I += 2) {
+ if (SDValue Packed = peekThroughF32x2Copy(MemN->getOperand(I),
+ MemN->getOperand(I + 1))) {
+ NewOps.push_back(Packed);
+ ++NumPacked;
+ } else {
+ NumPacked = 0;
+ break;
+ }
+ }
+
+ if (NumPacked) {
+ return DCI.DAG.getMemIntrinsicNode(
+ *NewOpcode, SDLoc(N), N->getVTList(), NewOps, MVT::i64,
+ MemN->getPointerInfo(), MemN->getAlign(),
+ MachineMemOperand::MOStore);
+ }
+ }
+ }
+ return SDValue();
+}
+
static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
- std::size_t Back) {
+ std::size_t Back,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
if (all_of(N->ops().drop_front(Front).drop_back(Back),
[](const SDUse &U) { return U.get()->isUndef(); }))
// Operand 0 is the previous value in the chain. Cannot return EntryToken
// as the previous value will become unused and eliminated later.
return N->getOperand(0);
+ if (SDValue V = PerformPackedF32StoreCombine(N, DCI, OptLevel))
+ return V;
+
return SDValue();
}
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
// Operands from the 3rd to the 2nd last one are the values to be stored.
// {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, 3, 1);
+ return PerformStoreCombineHelper(N, 3, 1, DCI, OptLevel);
}
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
// Operands from the 2nd to the last one are the values to be stored
- return PerformStoreCombineHelper(N, 2, 0);
+ return PerformStoreCombineHelper(N, 2, 0, DCI, OptLevel);
}
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5055,10 +5184,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
IsPTXVectorType(VectorVT.getSimpleVT()))
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
- // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
+ // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
// handle them OK.
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
- VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
+ VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32)
return SDValue();
// Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5188,6 +5317,78 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
}
+static SDValue PerformORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+
+ // i64 = or (i64 = zero_extend A, i64 = shl (i64 = any_extend B, 32))
+ // -> i64 = build_pair (A, B)
+ if (N->getValueType(0) == MVT::i64 && Op0.getOpcode() == ISD::ZERO_EXTEND &&
+ Op1.getOpcode() == ISD::SHL) {
+ SDValue SHLOp0 = Op1.getOperand(0);
+ SDValue SHLOp1 = Op1.getOperand(1);
+ if (const auto *Const = dyn_cast<ConstantSDNode>(SHLOp1);
+ Const && Const->getZExtValue() == 32 &&
+ SHLOp0.getOpcode() == ISD::ANY_EXTEND) {
+ SDLoc DL(N);
+ return DCI.DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+ {Op0.getOperand(0), SHLOp0.getOperand(0)});
+ }
+ }
+ return SDValue();
+}
+
+static SDValue PerformTRUNCATECombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op = N->getOperand(0);
+ if (Op.getOpcode() == ISD::SRL) {
+ SDValue SrlOp = Op.getOperand(0);
+ SDValue SrlSh = Op.getOperand(1);
+    // i32 = truncate (i64 = srl (i64 = build_pair (A, B), 32))
+    // -> i32 B
+ if (const auto *Const = dyn_cast<ConstantSDNode>(SrlSh);
+ Const && Const->getZExtValue() == 32) {
+ if (SrlOp.getOpcode() == ISD::BUILD_PAIR)
+ return SrlOp.getOperand(1);
+ }
+ }
+
+ return SDValue();
+}
+
+static SDValue PerformBUILD_PAIRCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ EVT ToVT = N->getValueType(0);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+ // -> i64 X
+ if (ToVT == MVT::i64 && Op0.getOpcode() == ISD::CopyFromReg &&
+ Op1.getNode() == Op0.getNode() && Op0 != Op1) {
+ SDValue CFRChain = Op0.getOperand(0);
+ Register Reg = cast<RegisterSDNode>(Op0.getOperand(1))->getReg();
+ if (CFRChain.getOpcode() == ISD::CopyToReg &&
+ cast<RegisterSDNode>(CFRChain.getOperand(1))->getReg() == Reg) {
+ SDValue Value = CFRChain.getOperand(2);
+ return Value;
+ }
+ }
+
+ return SDValue();
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5211,17 +5412,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
- return PerformStoreRetvalCombine(N);
+ return PerformStoreRetvalCombine(N, DCI, OptLevel);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N);
+ return PerformStoreParamCombine(N, DCI, OptLevel);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
+ case ISD::OR:
+ return PerformORCombine(N, DCI, OptLevel);
+ case ISD::TRUNCATE:
+ return PerformTRUNCATECombine(N, DCI, OptLevel);
+ case ISD::BUILD_PAIR:
+ return PerformBUILD_PAIRCombine(N, DCI, OptLevel);
}
return SDValue();
}
@@ -5478,6 +5685,59 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
+ SmallVectorImpl<SDValue> &Results) {
+ SDLoc DL(N);
+ EVT OldResultTy = N->getValueType(0); // <2 x float>
+ assert(OldResultTy == MVT::v2f32 && "Unexpected result type for F32x2 op!");
+
+ SmallVector<SDValue> NewOps;
+
+ // whether we use FTZ (TODO)
+  // TODO: decide whether to emit the FTZ (flush-to-zero) variant
+ // replace with NVPTX F32x2 op:
+ unsigned Opcode;
+ switch (N->getOpcode()) {
+ case ISD::FADD:
+ Opcode = NVPTXISD::FADD_F32X2;
+ break;
+ case ISD::FSUB:
+ Opcode = NVPTXISD::FSUB_F32X2;
+ break;
+ case ISD::FMUL:
+ Opcode = NVPTXISD::FMUL_F32X2;
+ break;
+ case ISD::FMA:
+ Opcode = NVPTXISD::FMA_F32X2;
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+
+ // bitcast operands: <2 x float> -> i64
+ for (const SDValue &Op : N->ops())
+ NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
+
+ SDValue Chain = DAG.getEntryNode();
+
+ // break packed result into two f32 registers for later instructions that may
+ // access element #0 or #1
+ SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+ Register DestReg = RegInfo.createVirtualRegister(
+ DAG.getTargetLoweringInfo().getRegClassFor(MVT::i64));
+ SDValue RegCopy = DAG.getCopyToReg(Chain, DL, DestReg, NewValue);
+ SDValue Explode = DAG.getNode(ISD::CopyFromReg, DL,
+ {MVT::f32, MVT::f32, Chain.getValueType()},
+ {RegCopy, DAG.getRegister(DestReg, MVT::i64)});
+ // cast i64 result of new op back to <2 x float>
+ Results.push_back(DAG.getBitcast(
+ OldResultTy,
+ DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+ {DAG.getBitcast(MVT::i32, Explode.getValue(0)),
+ DAG.getBitcast(MVT::i32, Explode.getValue(1))})));
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -5495,6 +5755,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ case ISD::FMA:
+ ReplaceF32x2Op(N, DAG, Results);
+ return;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..8fd4ded42a238a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
FSHR_CLAMP,
MUL_WIDE_SIGNED,
MUL_WIDE_UNSIGNED,
+ FADD_F32X2,
+ FMUL_F32X2,
+ FSUB_F32X2,
+ FMA_F32X2,
SETP_F16X2,
SETP_BF16X2,
BFE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7d9697e40e6aba..b0eb9bbbb2456a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
+def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
def True : Predicate<"true">;
def False : Predicate<"false">;
@@ -2638,13 +2639,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$a), "$a",
[(LastCallArg (i32 0), vt:$a)]>;
-def CallArgI64 : CallArgInst<Int64Regs>;
+def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
def CallArgF64 : CallArgInst<Float64Regs>;
def CallArgF32 : CallArgInst<Float32Regs>;
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
+def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3371,6 +3372,9 @@ let hasSideEffects = false in {
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
(ins Float32Regs:$s1, Float32Regs:$s2),
"mov.b64 \t$d, {{$s1, $s2}};", []>;
+ def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
+ (ins Float32Regs:$s1, Float32Regs:$s2),
+ "mov.b64 \t$d, {{$s1, $s2}};", []>;
// unpack a larger int register to a set of smaller int registers
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3383,6 +3387,9 @@ let hasSideEffects = false in {
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
(ins Int64Regs:$s),
"mov.b64 \t{{$d1, $d2}}, $...
[truncated]
``````````
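As a concrete instance of the packing/unpacking cleanups mentioned above, consider extracting one lane of a packed result (again a hypothetical sketch, not from the PR's tests):

```llvm
; Hypothetical: extract the high element of a packed f32x2 result.
; With the new TRUNCATE/BUILD_PAIR combines, the unpack should lower to a
; single register move, roughly `mov.b64 {%f1, %f2}, %rd1;`, rather than
; a chain of shift/truncate bit operations on the i64.
define float @extract_hi(<2 x float> %a, <2 x float> %b) {
  %m = fmul <2 x float> %a, %b
  %e = extractelement <2 x float> %m, i32 1
  ret float %e
}
```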
https://github.com/llvm/llvm-project/pull/126337