[llvm] [NVPTX] fold movs into loads and stores (PR #144581)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 17 12:04:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

<details>
<summary>Changes</summary>

Fold movs into loads and stores by increasing the number of return values (for loads) or operands (for stores), so the packed value no longer needs a separate unpacking or packing mov. For example:

```
L: v2f16,ch = Load [p]
e0 = extractelt L, 0
e1 = extractelt L, 1
consume(e0, e1)
```

...becomes...

```
L: f16,f16,ch = LoadV2 [p]
consume(L:0, L:1)
```

---

Patch is 327.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144581.diff


23 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+241-34) 
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+246-271) 
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll (+16-18) 
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+177-211) 
- (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+41-63) 
- (modified) llvm/test/CodeGen/NVPTX/fexp2.ll (+54-57) 
- (modified) llvm/test/CodeGen/NVPTX/flog2.ll (+32-34) 
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+258-274) 
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+196-207) 
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+400-422) 
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+98-82) 
- (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+24-24) 
- (modified) llvm/test/CodeGen/NVPTX/ldg-invariant.ll (+6-7) 
- (modified) llvm/test/CodeGen/NVPTX/load-store-vectors.ll (+32-40) 
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/math-intrins.ll (+168-188) 
- (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+3-5) 
- (modified) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+140-148) 
- (modified) llvm/test/CodeGen/NVPTX/sext-setcc.ll (+2-5) 
- (modified) llvm/test/CodeGen/NVPTX/shift-opt.ll (+7-8) 
- (modified) llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll (+37-37) 
- (modified) llvm/test/CodeGen/NVPTX/vector-loads.ll (+10-18) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 492f4ab76fdbb..e736b2ca6a151 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
-  case MVT::v2f16:
-  case MVT::v2bf16:
   case MVT::v2f32:
   case MVT::v2f64:
-  case MVT::v4i8:
-  case MVT::v4i16:
   case MVT::v4i32:
-  case MVT::v4f16:
-  case MVT::v4bf16:
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
+  case MVT::v2i16:  // <1 x i16x2>
+  case MVT::v2f16:  // <1 x f16x2>
+  case MVT::v2bf16: // <1 x bf16x2>
+  case MVT::v4i8:   // <1 x i8x4>
+  case MVT::v4i16:  // <2 x i16x2>
+  case MVT::v4f16:  // <2 x f16x2>
+  case MVT::v4bf16: // <2 x bf16x2>
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
@@ -845,7 +845,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3464,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       unsigned I = 0;
       for (const unsigned NumElts : VectorInfo) {
         const EVT EltVT = VTs[I];
-        const EVT LoadVT = [&]() -> EVT {
-          // i1 is loaded/stored as i8.
-          if (EltVT == MVT::i1)
-            return MVT::i8;
-          // getLoad needs a vector type, but it can't handle
-          // vectors which contain v2f16 or v2bf16 elements. So we must load
-          // using i32 here and then bitcast back.
-          if (EltVT.isVector())
-            return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
-          return EltVT;
-        }();
+        // i1 is loaded/stored as i8
+        const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+        // If the element is a packed type (ex. v2f16, v4i8, etc) holding
+        // multiple elements.
+        const unsigned PackingAmt =
+            LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+        const EVT VecVT = EVT::getVectorVT(
+            F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
 
-        const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
@@ -3496,8 +3493,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         if (P.getNode())
           P.getNode()->setIROrder(Arg.getArgNo() + 1);
         for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                    DAG.getIntPtrConstant(J, dl));
+          SDValue Elt =
+              DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+                                            : ISD::EXTRACT_VECTOR_ELT,
+                          dl, LoadVT, P, DAG.getIntPtrConstant(J * PackingAmt, dl));
 
           // Extend or truncate the element if necessary (e.g. an i8 is loaded
           // into an i16 register)
@@ -3511,9 +3510,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
                               Elt);
           } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
             Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
-          } else {
-            // v2f16 was loaded as an i32. Now we must bitcast it back.
-            Elt = DAG.getBitcast(EltVT, Elt);
           }
           InVals.push_back(Elt);
         }
@@ -5047,26 +5043,229 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// Combine extractelts into a load by increasing the number of return values.
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Don't run this optimization before the legalizer
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  EVT ElemVT = N->getValueType(0);
+  if (!Isv2x16VT(ElemVT))
+    return SDValue();
+
+  // Check whether all outputs are either used by an extractelt or are
+  // glue/chain nodes
+  if (!all_of(N->uses(), [&](SDUse &U) {
+        return U.getValueType() != ElemVT ||
+               (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+                // also check that the extractelt is used if this is an
+                // ISD::LOAD, otherwise it may be optimized by something else
+                (N->getOpcode() != ISD::LOAD || !U.getUser()->use_empty()));
+      }))
+    return SDValue();
+
+  auto *LD = cast<MemSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  SDLoc DL(LD);
+
+  // the new opcode after we double the number of operands
+  NVPTXISD::NodeType Opcode;
+  SmallVector<SDValue> Operands(LD->ops());
+  switch (LD->getOpcode()) {
+  // Any packed type is legal, so the legalizer will not have lowered ISD::LOAD
+  // -> NVPTXISD::Load. We have to do it here.
+  case ISD::LOAD:
+    Opcode = NVPTXISD::LoadV2;
+    {
+      Operands.push_back(DCI.DAG.getIntPtrConstant(
+          cast<LoadSDNode>(LD)->getExtensionType(), DL));
+      Align Alignment = LD->getAlign();
+      const auto &TD = DCI.DAG.getDataLayout();
+      Align PrefAlign =
+          TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DCI.DAG.getContext()));
+      if (Alignment < PrefAlign) {
+        // This load is not sufficiently aligned, so bail out and let this
+        // vector load be scalarized.  Note that we may still be able to emit
+        // smaller vector loads.  For example, if we are loading a <4 x float>
+        // with an alignment of 8, this check will fail but the legalizer will
+        // try again with 2 x <2 x float>, which will succeed with an alignment
+        // of 8.
+        return SDValue();
+      }
+    }
+    break;
+  case NVPTXISD::LoadParamV2:
+    Opcode = NVPTXISD::LoadParamV4;
+    break;
+  case NVPTXISD::LoadV2:
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case NVPTXISD::LoadV4:
+    // PTX doesn't support v8 for 16-bit values
+  case NVPTXISD::LoadV8:
+    // PTX doesn't support the next doubling of outputs
+    return SDValue();
+  }
+
+  SmallVector<EVT> NewVTs;
+  for (EVT VT : LD->values()) {
+    if (VT == ElemVT) {
+      const EVT ScalarVT = ElemVT.getVectorElementType();
+      NewVTs.insert(NewVTs.end(), {ScalarVT, ScalarVT});
+    } else
+      NewVTs.push_back(VT);
+  }
+
+  // Create the new load
+  SDValue NewLoad =
+      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+                                  Operands, MemVT, LD->getMemOperand());
+
+  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+  // the outputs the same. These nodes will be optimized away in later
+  // DAGCombiner iterations.
+  SmallVector<SDValue> Results;
+  for (unsigned I = 0; I < NewLoad->getNumValues();) {
+    if (NewLoad->getValueType(I) == ElemVT.getVectorElementType()) {
+      Results.push_back(DCI.DAG.getBuildVector(
+          ElemVT, DL, {NewLoad.getValue(I), NewLoad.getValue(I + 1)}));
+      I += 2;
+    } else {
+      Results.push_back(NewLoad.getValue(I));
+      I += 1;
+    }
+  }
+
+  return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store. This may help lower register pressure.
+///
+/// ex:
+/// v: v2f16 = build_vector a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+                                          TargetLowering::DAGCombinerInfo &DCI,
+                                          unsigned Front, unsigned Back) {
+  // Don't run this optimization before the legalizer
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  // Get the type of the operands being stored.
+  EVT ElementVT = N->getOperand(Front).getValueType();
+
+  if (!Isv2x16VT(ElementVT))
+    return SDValue();
+
+  // We want to run this as late as possible since other optimizations may
+  // eliminate the BUILD_VECTORs.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  auto *ST = cast<MemSDNode>(N);
+  EVT MemVT = ElementVT.getVectorElementType();
+
+  // The new opcode after we double the number of operands.
+  NVPTXISD::NodeType Opcode;
+  switch (N->getOpcode()) {
+  case NVPTXISD::StoreParam:
+    Opcode = NVPTXISD::StoreParamV2;
+    break;
+  case NVPTXISD::StoreParamV2:
+    Opcode = NVPTXISD::StoreParamV4;
+    break;
+  case NVPTXISD::StoreRetval:
+    Opcode = NVPTXISD::StoreRetvalV2;
+    break;
+  case NVPTXISD::StoreRetvalV2:
+    Opcode = NVPTXISD::StoreRetvalV4;
+    break;
+  case NVPTXISD::StoreV2:
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  case NVPTXISD::StoreV4:
+    // PTX doesn't support v8 for 16-bit values
+  case NVPTXISD::StoreParamV4:
+  case NVPTXISD::StoreRetvalV4:
+  case NVPTXISD::StoreV8:
+    // PTX doesn't support the next doubling of operands for these opcodes.
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled store opcode");
+  }
+
+  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
+  // their elements.
+  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+    if (BV.getOpcode() != ISD::BUILD_VECTOR)
+      return SDValue();
+
+    // If the operand has multiple uses, this optimization can increase register
+    // pressure.
+    if (!BV.hasOneUse())
+      return SDValue();
+
+    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+    // any signs they may be folded by some other pattern or rule.
+    for (SDValue Op : BV->ops()) {
+      // Peek through bitcasts
+      if (Op.getOpcode() == ISD::BITCAST)
+        Op = Op.getOperand(0);
+
+      // This may be folded into a PRMT.
+      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32)
+        return SDValue();
+
+      // This may be folded into cvt.bf16x2
+      if (Op.getOpcode() == ISD::FP_ROUND)
+        return SDValue();
+    }
+    Operands.insert(Operands.end(), {BV.getOperand(0), BV.getOperand(1)});
+  }
+  for (SDValue Op : N->ops().take_back(Back))
+    Operands.push_back(Op);
+
+  // Now we replace the store
+  return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+          Operands, MemVT, ST->getMemOperand());
+}
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         unsigned Front, unsigned Back) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);
 
-  return SDValue();
+  return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  return combinePackingMovIntoStore(N, DCI, 1, 2);
 }
 
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   //   {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, DCI, 3, 1);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, DCI, 2, 0);
 }
 
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5896,22 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformREMCombine(N, DCI, OptLevel);
     case ISD::SETCC:
       return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+    case ISD::LOAD:
+    case NVPTXISD::LoadParamV2:
+    case NVPTXISD::LoadV2:
+    case NVPTXISD::LoadV4:
+      return combineUnpackingMovIntoLoad(N, DCI);
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
-      return PerformStoreRetvalCombine(N);
+      return PerformStoreRetvalCombine(N, DCI);
     case NVPTXISD::StoreParam:
     case NVPTXISD::StoreParamV2:
     case NVPTXISD::StoreParamV4:
-      return PerformStoreParamCombine(N);
+      return PerformStoreParamCombine(N, DCI);
+    case NVPTXISD::StoreV2:
+    case NVPTXISD::StoreV4:
+      return PerformStoreCombine(N, DCI);
     case ISD::EXTRACT_VECTOR_ELT:
       return PerformEXTRACTCombine(N, DCI);
     case ISD::VSELECT:
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 32225ed04e2d9..95af9c64a73ac 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -146,37 +146,35 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
 ; SM70-NEXT:    .reg .b16 %rs<5>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b32 %r<22>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
-; SM70-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
-; SM70-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
-; SM70-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT:    ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM70-NEXT:    ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM70-NEXT:    cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT:    shl.b32 %r2, %r1, 16;
 ; SM70-NEXT:    cvt.u32.u16 %r3, %rs2;
 ; SM70-NEXT:    shl.b32 %r4, %r3, 16;
-; SM70-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT:    cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT:    shl.b32 %r6, %r5, 16;
-; SM70-NEXT:    add.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT:    bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT:    add.s32 %r9, %r8, %r7;
-; SM70-NEXT:    add.s32 %r10, %r9, 32767;
-; SM70-NEXT:    setp.nan.f32 %p1, %r7, %r7;
-; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
-; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
+; SM70-NEXT:    add.rn.f32 %r5, %r4, %r2;
+; SM70-NEXT:    bfe.u32 %r6, %r5, 16, 1;
+; SM70-NEXT:    add.s32 %r7, %r6, %r5;
+; SM70-NEXT:    add.s32 %r8, %r7, 32767;
+; SM70-NEXT:    setp.nan.f32 %p1, %r5, %r5;
+; SM70-NEXT:    or.b32 %r9, %r5, 4194304;
+; SM70-NEXT:    selp.b32 %r10, %r9, %r8, %p1;
+; SM70-NEXT:    cvt.u32.u16 %r11, %rs3;
+; SM70-NEXT:    shl.b32 %r12, %r11, 16;
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
-; SM70-NEXT:    cvt.u32.u16 %r15, %rs3;
-; SM70-NEXT:    shl.b32 %r16, %r15, 16;
-; SM70-NEXT:    add.rn.f32 %r17, %r16, %r14;
-; SM70-NEXT:    bfe.u32 %r18, %r17, 16, 1;
-; SM70-NEXT:    add.s32 %r19, %r18, %r17;
-; SM70-NEXT:    add.s32 %r20, %r19, 32767;
-; SM70-NEXT:    setp.nan.f32 %p2, %r17, %r17;
-; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
-; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
-; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
+; SM70-NEXT:    add.rn.f32 %r15, %r14, %r12;
+; SM70-NEXT:    bfe.u32 %r16, %r15, 16, 1;
+; SM70-NEXT:    add.s32 %r17, %r16, %r15;
+; SM70-NEXT:    add.s32 %r18, %r17, 32767;
+; SM70-NEXT:    setp.nan.f32 %p2, %r15, %r15;
+; SM70-NEXT:    or.b32 %r19, %r15, 4194304;
+; SM70-NEXT:    selp.b32 %r20, %r19, %r18, %p2;
+; SM70-NEXT:    prmt.b32 %r21, %r20, %r10, 0x7632U;
+; SM70-NEXT:    st.param.b32 [func_retval0], %r21;
 ; SM70-NEXT:    ret;
 ;
 ; SM80-LABEL: test_faddx2(
@@ -184,31 +182,29 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_1];
-; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
+; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
 ; SM80-NEXT:    mov.b32 %r3, 1065369472;
-; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r1, %r3, %r2;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_faddx2(
 ; SM80-FTZ:       {
 ; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
-; SM80-FTZ-NEXT:    .reg .b32 %r<10>;
+; SM80-FTZ-NEXT:    .reg .b32 %r<8>;
 ; SM80-FTZ-EMPTY:
 ; SM80-FTZ-NEXT:  // %bb.0:
-; SM80-FTZ-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-FTZ-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r3, %rs1;
-; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r4, %rs3;
-; SM80-FTZ-NEXT:    add.rn.ftz.f32 %r5, %r4, %r3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r6, %rs2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r7, %rs4;
-; SM80-FTZ-NEXT:    add.rn.ftz.f32 %r8, %r7, %r6;
-; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r9, %r8, %r5;
-; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r9;
+; SM80-FTZ-NEXT:    ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM80-FTZ-NEXT:    ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r1, %rs3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r2, %rs1;
+; SM80-FTZ-NEXT:    add.rn.ftz.f32 %r3, %r2, %r1;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r4, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r5, %rs2;
+; SM80-FTZ-NEXT:    add.rn.ftz.f32 %r6, %r5, %r4;
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r7, %r6, %r3;
+; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r7;
 ; SM80-FTZ-NEXT:    ret;
 ;
 ; SM90-LABEL: test_faddx2(
@@ -216,9 +212,9 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM90-NEXT:    .reg .b32 %r<4>;
 ; SM90-EMPTY:
 ; SM90-NEXT:  // %bb.0:
-; SM90-NEXT:    ld.param.b32 %r1, [test_faddx2_param_1];
-; SM90-NEXT:    ld.param.b32 %r2, [test_faddx2_param_0];
-; SM90-NEXT:    add.rn.bf16x2 %r3, %r2, %r1;
+; SM90-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
+; SM90-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
+; SM90-NEXT:    add.rn.bf16x2 %r3, %r1, %r2;
 ; SM90-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM90-NEXT:    ret;
   %r = fadd <2 x bfloat> %a, %b
@@ -230,37 +226,35 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
 ; SM70-NEXT:    .reg .b16 %rs<5>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b32 %r<22>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
-; SM70-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
-; SM70-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM70-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT:    ld.param.v2.b16 {%rs1, %rs2}, [test_fsubx2_param_0];
+; SM70-NEXT:    ld.param.v2.b16 {%rs3, %rs4}, [test_fsubx2_param_1];
+; SM70-NEXT:    cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT:    shl.b32 %r2, %r1, 16;
 ; SM70-NEXT:    cvt.u32.u16 %r3, %rs2;
 ; SM70-NEXT:    shl.b32 %r4, %r3, 16;
-; SM70-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT:    cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT:    shl.b32 %r6, %r5, 16;
-; SM70-NEXT:    sub.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT:    bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT:    a...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144581


More information about the llvm-commits mailing list