[llvm] [NVPTX] fold movs into loads and stores (PR #144581)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 17 12:04:22 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Princeton Ferro (Prince781)
Changes:
Fold movs into loads and stores by increasing the number of return values (for loads) or operands (for stores). For example, on the load side:
```
L: v2f16,ch = Load [p]
e0 = extractelt L, 0
e1 = extractelt L, 1
consume(e0, e1)
```
...becomes...
```
L: f16,f16,ch = LoadV2 [p]
consume(L:0, L:1)
```
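The store-side combine works in the opposite direction. Borrowing the example from the doc comment on `combinePackingMovIntoStore` in this patch, a packing `build_vector` that only feeds a store is folded by widening the store's operand list:
```
v: v2f16 = build_vector a:f16, b:f16
StoreRetval v
```
...becomes...
```
StoreRetvalV2 a:f16, b:f16
```
In the updated tests these folds show up as, for example, a single `ld.param.v2.b16 {%rs1, %rs2}` replacing an `ld.param.b32` followed by a `mov.b32 {%rs1, %rs2}` unpack.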
---
Patch is 327.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144581.diff
23 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+241-34)
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+246-271)
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll (+16-18)
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+177-211)
- (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+41-63)
- (modified) llvm/test/CodeGen/NVPTX/fexp2.ll (+54-57)
- (modified) llvm/test/CodeGen/NVPTX/flog2.ll (+32-34)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+258-274)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+196-207)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+400-422)
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+98-82)
- (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+1-3)
- (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+24-24)
- (modified) llvm/test/CodeGen/NVPTX/ldg-invariant.ll (+6-7)
- (modified) llvm/test/CodeGen/NVPTX/load-store-vectors.ll (+32-40)
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/math-intrins.ll (+168-188)
- (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+3-5)
- (modified) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+140-148)
- (modified) llvm/test/CodeGen/NVPTX/sext-setcc.ll (+2-5)
- (modified) llvm/test/CodeGen/NVPTX/shift-opt.ll (+7-8)
- (modified) llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll (+37-37)
- (modified) llvm/test/CodeGen/NVPTX/vector-loads.ll (+10-18)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 492f4ab76fdbb..e736b2ca6a151 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2i8:
- case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
- case MVT::v2f16:
- case MVT::v2bf16:
case MVT::v2f32:
case MVT::v2f64:
- case MVT::v4i8:
- case MVT::v4i16:
case MVT::v4i32:
- case MVT::v4f16:
- case MVT::v4bf16:
case MVT::v4f32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
if (!CanLowerTo256Bit)
return std::nullopt;
LLVM_FALLTHROUGH;
+ case MVT::v2i16: // <1 x i16x2>
+ case MVT::v2f16: // <1 x f16x2>
+ case MVT::v2bf16: // <1 x bf16x2>
+ case MVT::v4i8: // <1 x i8x4>
+ case MVT::v4i16: // <2 x i16x2>
+ case MVT::v4f16: // <2 x f16x2>
+ case MVT::v4bf16: // <2 x bf16x2>
case MVT::v8i8: // <2 x i8x4>
case MVT::v8f16: // <4 x f16x2>
case MVT::v8bf16: // <4 x bf16x2>
@@ -845,7 +845,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3464,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
const EVT EltVT = VTs[I];
- const EVT LoadVT = [&]() -> EVT {
- // i1 is loaded/stored as i8.
- if (EltVT == MVT::i1)
- return MVT::i8;
- // getLoad needs a vector type, but it can't handle
- // vectors which contain v2f16 or v2bf16 elements. So we must load
- // using i32 here and then bitcast back.
- if (EltVT.isVector())
- return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
- return EltVT;
- }();
+ // i1 is loaded/stored as i8
+ const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+ // If the element is a packed type (ex. v2f16, v4i8, etc) holding
+ // multiple elements.
+ const unsigned PackingAmt =
+ LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+ const EVT VecVT = EVT::getVectorVT(
+ F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
- const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3496,8 +3493,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (P.getNode())
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
- DAG.getIntPtrConstant(J, dl));
+ SDValue Elt =
+ DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+ : ISD::EXTRACT_VECTOR_ELT,
+ dl, LoadVT, P, DAG.getIntPtrConstant(J * PackingAmt, dl));
// Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
@@ -3511,9 +3510,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
Elt);
} else if (ExpactedVT.bitsLT(Elt.getValueType())) {
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
- } else {
- // v2f16 was loaded as an i32. Now we must bitcast it back.
- Elt = DAG.getBitcast(EltVT, Elt);
}
InVals.push_back(Elt);
}
@@ -5047,26 +5043,229 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
- std::size_t Back) {
+/// Combine extractelts into a load by increasing the number of return values.
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ // Don't run this optimization before the legalizer
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ EVT ElemVT = N->getValueType(0);
+ if (!Isv2x16VT(ElemVT))
+ return SDValue();
+
+ // Check whether all outputs are either used by an extractelt or are
+ // glue/chain nodes
+ if (!all_of(N->uses(), [&](SDUse &U) {
+ return U.getValueType() != ElemVT ||
+ (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ // also check that the extractelt is used if this is an
+ // ISD::LOAD, otherwise it may be optimized by something else
+ (N->getOpcode() != ISD::LOAD || !U.getUser()->use_empty()));
+ }))
+ return SDValue();
+
+ auto *LD = cast<MemSDNode>(N);
+ EVT MemVT = LD->getMemoryVT();
+ SDLoc DL(LD);
+
+ // the new opcode after we double the number of operands
+ NVPTXISD::NodeType Opcode;
+ SmallVector<SDValue> Operands(LD->ops());
+ switch (LD->getOpcode()) {
+ // Any packed type is legal, so the legalizer will not have lowered ISD::LOAD
+ // -> NVPTXISD::Load. We have to do it here.
+ case ISD::LOAD:
+ Opcode = NVPTXISD::LoadV2;
+ {
+ Operands.push_back(DCI.DAG.getIntPtrConstant(
+ cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ Align Alignment = LD->getAlign();
+ const auto &TD = DCI.DAG.getDataLayout();
+ Align PrefAlign =
+ TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DCI.DAG.getContext()));
+ if (Alignment < PrefAlign) {
+ // This load is not sufficiently aligned, so bail out and let this
+ // vector load be scalarized. Note that we may still be able to emit
+ // smaller vector loads. For example, if we are loading a <4 x float>
+ // with an alignment of 8, this check will fail but the legalizer will
+ // try again with 2 x <2 x float>, which will succeed with an alignment
+ // of 8.
+ return SDValue();
+ }
+ }
+ break;
+ case NVPTXISD::LoadParamV2:
+ Opcode = NVPTXISD::LoadParamV4;
+ break;
+ case NVPTXISD::LoadV2:
+ Opcode = NVPTXISD::LoadV4;
+ break;
+ case NVPTXISD::LoadV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::LoadV8:
+ // PTX doesn't support the next doubling of outputs
+ return SDValue();
+ }
+
+ SmallVector<EVT> NewVTs;
+ for (EVT VT : LD->values()) {
+ if (VT == ElemVT) {
+ const EVT ScalarVT = ElemVT.getVectorElementType();
+ NewVTs.insert(NewVTs.end(), {ScalarVT, ScalarVT});
+ } else
+ NewVTs.push_back(VT);
+ }
+
+ // Create the new load
+ SDValue NewLoad =
+ DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+ Operands, MemVT, LD->getMemOperand());
+
+ // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+ // the outputs the same. These nodes will be optimized away in later
+ // DAGCombiner iterations.
+ SmallVector<SDValue> Results;
+ for (unsigned I = 0; I < NewLoad->getNumValues();) {
+ if (NewLoad->getValueType(I) == ElemVT.getVectorElementType()) {
+ Results.push_back(DCI.DAG.getBuildVector(
+ ElemVT, DL, {NewLoad.getValue(I), NewLoad.getValue(I + 1)}));
+ I += 2;
+ } else {
+ Results.push_back(NewLoad.getValue(I));
+ I += 1;
+ }
+ }
+
+ return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store. This may help lower register pressure.
+///
+/// ex:
+/// v: v2f16 = build_vector a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned Front, unsigned Back) {
+ // Don't run this optimization before the legalizer
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ // Get the type of the operands being stored.
+ EVT ElementVT = N->getOperand(Front).getValueType();
+
+ if (!Isv2x16VT(ElementVT))
+ return SDValue();
+
+ // We want to run this as late as possible since other optimizations may
+ // eliminate the BUILD_VECTORs.
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ auto *ST = cast<MemSDNode>(N);
+ EVT MemVT = ElementVT.getVectorElementType();
+
+ // The new opcode after we double the number of operands.
+ NVPTXISD::NodeType Opcode;
+ switch (N->getOpcode()) {
+ case NVPTXISD::StoreParam:
+ Opcode = NVPTXISD::StoreParamV2;
+ break;
+ case NVPTXISD::StoreParamV2:
+ Opcode = NVPTXISD::StoreParamV4;
+ break;
+ case NVPTXISD::StoreRetval:
+ Opcode = NVPTXISD::StoreRetvalV2;
+ break;
+ case NVPTXISD::StoreRetvalV2:
+ Opcode = NVPTXISD::StoreRetvalV4;
+ break;
+ case NVPTXISD::StoreV2:
+ MemVT = ST->getMemoryVT();
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ case NVPTXISD::StoreV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::StoreParamV4:
+ case NVPTXISD::StoreRetvalV4:
+ case NVPTXISD::StoreV8:
+ // PTX doesn't support the next doubling of operands for these opcodes.
+ return SDValue();
+ default:
+ llvm_unreachable("Unhandled store opcode");
+ }
+
+ // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
+ // their elements.
+ SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+ for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+ if (BV.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // If the operand has multiple uses, this optimization can increase register
+ // pressure.
+ if (!BV.hasOneUse())
+ return SDValue();
+
+ // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+ // any signs they may be folded by some other pattern or rule.
+ for (SDValue Op : BV->ops()) {
+ // Peek through bitcasts
+ if (Op.getOpcode() == ISD::BITCAST)
+ Op = Op.getOperand(0);
+
+ // This may be folded into a PRMT.
+ if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+ Op->getOperand(0).getValueType() == MVT::i32)
+ return SDValue();
+
+ // This may be folded into cvt.bf16x2
+ if (Op.getOpcode() == ISD::FP_ROUND)
+ return SDValue();
+ }
+ Operands.insert(Operands.end(), {BV.getOperand(0), BV.getOperand(1)});
+ }
+ for (SDValue Op : N->ops().take_back(Back))
+ Operands.push_back(Op);
+
+ // Now we replace the store
+ return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+ Operands, MemVT, ST->getMemOperand());
+}
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned Front, unsigned Back) {
if (all_of(N->ops().drop_front(Front).drop_back(Back),
[](const SDUse &U) { return U.get()->isUndef(); }))
// Operand 0 is the previous value in the chain. Cannot return EntryToken
// as the previous value will become unused and eliminated later.
return N->getOperand(0);
- return SDValue();
+ return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ return combinePackingMovIntoStore(N, DCI, 1, 2);
}
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 3rd to the 2nd last one are the values to be stored.
// {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, 3, 1);
+ return PerformStoreCombineHelper(N, DCI, 3, 1);
}
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 2nd to the last one are the values to be stored
- return PerformStoreCombineHelper(N, 2, 0);
+ return PerformStoreCombineHelper(N, DCI, 2, 0);
}
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5896,22 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+ case ISD::LOAD:
+ case NVPTXISD::LoadParamV2:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ return combineUnpackingMovIntoLoad(N, DCI);
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
- return PerformStoreRetvalCombine(N);
+ return PerformStoreRetvalCombine(N, DCI);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N);
+ return PerformStoreParamCombine(N, DCI);
+ case NVPTXISD::StoreV2:
+ case NVPTXISD::StoreV4:
+ return PerformStoreCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 32225ed04e2d9..95af9c64a73ac 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -146,37 +146,35 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
; SM70-NEXT: .reg .b16 %rs<5>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b32 %r<22>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
-; SM70-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM70-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM70-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM70-NEXT: cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT: shl.b32 %r2, %r1, 16;
; SM70-NEXT: cvt.u32.u16 %r3, %rs2;
; SM70-NEXT: shl.b32 %r4, %r3, 16;
-; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT: cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT: shl.b32 %r6, %r5, 16;
-; SM70-NEXT: add.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT: bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT: add.s32 %r9, %r8, %r7;
-; SM70-NEXT: add.s32 %r10, %r9, 32767;
-; SM70-NEXT: setp.nan.f32 %p1, %r7, %r7;
-; SM70-NEXT: or.b32 %r11, %r7, 4194304;
-; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
+; SM70-NEXT: add.rn.f32 %r5, %r4, %r2;
+; SM70-NEXT: bfe.u32 %r6, %r5, 16, 1;
+; SM70-NEXT: add.s32 %r7, %r6, %r5;
+; SM70-NEXT: add.s32 %r8, %r7, 32767;
+; SM70-NEXT: setp.nan.f32 %p1, %r5, %r5;
+; SM70-NEXT: or.b32 %r9, %r5, 4194304;
+; SM70-NEXT: selp.b32 %r10, %r9, %r8, %p1;
+; SM70-NEXT: cvt.u32.u16 %r11, %rs3;
+; SM70-NEXT: shl.b32 %r12, %r11, 16;
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
-; SM70-NEXT: cvt.u32.u16 %r15, %rs3;
-; SM70-NEXT: shl.b32 %r16, %r15, 16;
-; SM70-NEXT: add.rn.f32 %r17, %r16, %r14;
-; SM70-NEXT: bfe.u32 %r18, %r17, 16, 1;
-; SM70-NEXT: add.s32 %r19, %r18, %r17;
-; SM70-NEXT: add.s32 %r20, %r19, 32767;
-; SM70-NEXT: setp.nan.f32 %p2, %r17, %r17;
-; SM70-NEXT: or.b32 %r21, %r17, 4194304;
-; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
-; SM70-NEXT: st.param.b32 [func_retval0], %r23;
+; SM70-NEXT: add.rn.f32 %r15, %r14, %r12;
+; SM70-NEXT: bfe.u32 %r16, %r15, 16, 1;
+; SM70-NEXT: add.s32 %r17, %r16, %r15;
+; SM70-NEXT: add.s32 %r18, %r17, 32767;
+; SM70-NEXT: setp.nan.f32 %p2, %r15, %r15;
+; SM70-NEXT: or.b32 %r19, %r15, 4194304;
+; SM70-NEXT: selp.b32 %r20, %r19, %r18, %p2;
+; SM70-NEXT: prmt.b32 %r21, %r20, %r10, 0x7632U;
+; SM70-NEXT: st.param.b32 [func_retval0], %r21;
; SM70-NEXT: ret;
;
; SM80-LABEL: test_faddx2(
@@ -184,31 +182,29 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
-; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
+; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
; SM80-NEXT: mov.b32 %r3, 1065369472;
-; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
; SM80-FTZ: {
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
-; SM80-FTZ-NEXT: .reg .b32 %r<10>;
+; SM80-FTZ-NEXT: .reg .b32 %r<8>;
; SM80-FTZ-EMPTY:
; SM80-FTZ-NEXT: // %bb.0:
-; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r3, %rs1;
-; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r4, %rs3;
-; SM80-FTZ-NEXT: add.rn.ftz.f32 %r5, %r4, %r3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r6, %rs2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r7, %rs4;
-; SM80-FTZ-NEXT: add.rn.ftz.f32 %r8, %r7, %r6;
-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r9, %r8, %r5;
-; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r9;
+; SM80-FTZ-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM80-FTZ-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r1, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r2, %rs1;
+; SM80-FTZ-NEXT: add.rn.ftz.f32 %r3, %r2, %r1;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r4, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r5, %rs2;
+; SM80-FTZ-NEXT: add.rn.ftz.f32 %r6, %r5, %r4;
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r7, %r6, %r3;
+; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r7;
; SM80-FTZ-NEXT: ret;
;
; SM90-LABEL: test_faddx2(
@@ -216,9 +212,9 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM90-NEXT: .reg .b32 %r<4>;
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
-; SM90-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
-; SM90-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
-; SM90-NEXT: add.rn.bf16x2 %r3, %r2, %r1;
+; SM90-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
+; SM90-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
+; SM90-NEXT: add.rn.bf16x2 %r3, %r1, %r2;
; SM90-NEXT: st.param.b32 [func_retval0], %r3;
; SM90-NEXT: ret;
%r = fadd <2 x bfloat> %a, %b
@@ -230,37 +226,35 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
; SM70-NEXT: .reg .b16 %rs<5>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b32 %r<22>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
-; SM70-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
-; SM70-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_fsubx2_param_0];
+; SM70-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_fsubx2_param_1];
+; SM70-NEXT: cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT: shl.b32 %r2, %r1, 16;
; SM70-NEXT: cvt.u32.u16 %r3, %rs2;
; SM70-NEXT: shl.b32 %r4, %r3, 16;
-; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT: cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT: shl.b32 %r6, %r5, 16;
-; SM70-NEXT: sub.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT: bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT: a...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/144581