[llvm] 2d3c260 - [AArch64] break non-temporal loads over 256 into 256-loads and a smaller load

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 28 07:21:13 PDT 2022


Author: Florian Hahn
Date: 2022-09-28T15:20:26+01:00
New Revision: 2d3c260362a29404909dadd65e904e872b52e09c

URL: https://github.com/llvm/llvm-project/commit/2d3c260362a29404909dadd65e904e872b52e09c
DIFF: https://github.com/llvm/llvm-project/commit/2d3c260362a29404909dadd65e904e872b52e09c.diff

LOG: [AArch64] break non-temporal loads over 256 into 256-loads and a smaller load

Currently, non-temporal loads over 256 bits are broken up inefficiently. For example, `v17i32` gets broken into 2 128-bit loads. It is better if we can use
256-bit loads instead.

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D133421

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/nontemporal-load.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28751afd70bd..f7cbc61a9ab2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -899,8 +899,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
                        ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
-  if (Subtarget->supportsAddressTopByteIgnored())
-    setTargetDAGCombine(ISD::LOAD);
+  setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MSTORE);
 
@@ -18080,6 +18079,87 @@ static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) {
   return SDValue();
 }
 
+// Perform TBI simplification if supported by the target, then try to break up non-volatile, non-temporal, fixed-length vector loads on little-endian targets that are wider than 256 bits (but not a multiple of 256 bits) into 256-bit loads plus one smaller remainder load, so LDNPQ 256-bit load instructions can be selected. Returns SDValue(N, 0) when the load is not split.
+static SDValue performLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  if (Subtarget->supportsAddressTopByteIgnored())
+    performTBISimplification(N->getOperand(1), DCI, DAG);
+
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  if (LD->isVolatile() || !LD->isNonTemporal() || !Subtarget->isLittleEndian())
+    return SDValue(N, 0);
+
+  if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
+      MemVT.getSizeInBits() % 256 == 0 ||
+      256 % MemVT.getScalarSizeInBits() != 0)
+    return SDValue(N, 0);
+
+  SDLoc DL(LD);
+  SDValue Chain = LD->getChain();
+  SDValue BasePtr = LD->getBasePtr();
+  SDNodeFlags Flags = LD->getFlags();
+  SmallVector<SDValue, 4> LoadOps;
+  SmallVector<SDValue, 4> LoadOpsChain;
+  // Replace the non-temporal load wider than 256 bits with a series of 256-bit
+  // loads and one remainder vector load narrower than 256 bits. This way we
+  // can use 256-bit loads and reduce the number of load instructions generated.
+  MVT NewVT =
+      MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+                       256 / MemVT.getVectorElementType().getSizeInBits());
+  unsigned Num256Loads = MemVT.getSizeInBits() / 256;
+  // Create all 256-bit loads at byte offsets 0, 32, ..., (Num256Loads - 1) * 32.
+  for (unsigned I = 0; I < Num256Loads; I++) {
+    unsigned PtrOffset = I * 32;
+    SDValue NewPtr = DAG.getMemBasePlusOffset(
+        BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+    Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+    SDValue NewLoad = DAG.getLoad(
+        NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
+        NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
+    LoadOps.push_back(NewLoad);
+    LoadOpsChain.push_back(SDValue(cast<SDNode>(NewLoad), 1));
+  }
+
+  // Process the remaining bits of the load operation.
+  // The remainder is loaded with a narrower vector type and inserted at index
+  // 0 of an UNDEF vector of the 256-bit type so it can be concatenated with
+  // the other loads. The original type is recovered with EXTRACT_SUBVECTOR.
+  unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
+  unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
+  MVT RemainingVT = MVT::getVectorVT(
+      MemVT.getVectorElementType().getSimpleVT(),
+      BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
+  SDValue NewPtr =
+      DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+  Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+  SDValue RemainingLoad =
+      DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
+                  LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+                  LD->getMemOperand()->getFlags(), LD->getAAInfo());
+  SDValue UndefVector = DAG.getUNDEF(NewVT);
+  SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
+  SDValue ExtendedReminingLoad =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+                  {UndefVector, RemainingLoad, InsertIdx});
+  LoadOps.push_back(ExtendedReminingLoad);
+  LoadOpsChain.push_back(SDValue(cast<SDNode>(RemainingLoad), 1));
+  EVT ConcatVT =
+      EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+                       LoadOps.size() * NewVT.getVectorNumElements());
+  SDValue ConcatVectors =
+      DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
+  // Extract a vector of the original memory type from index 0 of the concat.
+  SDValue ExtractSubVector =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+                  {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+  SDValue TokenFactor =
+      DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
+  return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -20129,9 +20209,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:
-    if (performTBISimplification(N->getOperand(1), DCI, DAG))
-      return SDValue(N, 0);
-    break;
+    return performLOADCombine(N, DCI, DAG, Subtarget);
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MSTORE:

diff  --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
index 47556321aea4..288ba22e7928 100644
--- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll
+++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
@@ -320,12 +320,12 @@ define <16 x float> @test_ldnp_v16f32(<16 x float>* %A) {
 define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
 ; CHECK-LABEL: test_ldnp_v17f32:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q2, [x0, #32]
-; CHECK-NEXT:    ldp q3, q4, [x0]
-; CHECK-NEXT:    ldr s0, [x0, #64]
-; CHECK-NEXT:    stp q3, q4, [x8]
-; CHECK-NEXT:    stp q1, q2, [x8, #32]
-; CHECK-NEXT:    str s0, [x8, #64]
+; CHECK-NEXT:    ldnp q0, q1, [x0, #32]
+; CHECK-NEXT:    ldnp q2, q3, [x0]
+; CHECK-NEXT:    ldr s4, [x0, #64]
+; CHECK-NEXT:    stp q0, q1, [x8, #32]
+; CHECK-NEXT:    stp q2, q3, [x8]
+; CHECK-NEXT:    str s4, [x8, #64]
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-BE-LABEL: test_ldnp_v17f32:
@@ -354,24 +354,24 @@ define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
 define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v33f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q1, [x0]
-; CHECK-NEXT:    ldp q2, q3, [x0, #32]
-; CHECK-NEXT:    ldp q4, q5, [x0, #64]
-; CHECK-NEXT:    ldp q6, q7, [x0, #96]
-; CHECK-NEXT:    ldp q16, q17, [x0, #128]
-; CHECK-NEXT:    ldp q18, q19, [x0, #160]
-; CHECK-NEXT:    ldp q21, q22, [x0, #224]
-; CHECK-NEXT:    ldp q23, q24, [x0, #192]
-; CHECK-NEXT:    ldr d20, [x0, #256]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    ldnp q2, q3, [x0, #32]
+; CHECK-NEXT:    ldnp q4, q5, [x0, #64]
+; CHECK-NEXT:    ldnp q6, q7, [x0, #96]
+; CHECK-NEXT:    ldnp q16, q17, [x0, #128]
+; CHECK-NEXT:    ldnp q18, q19, [x0, #224]
+; CHECK-NEXT:    ldnp q20, q21, [x0, #192]
+; CHECK-NEXT:    ldnp q22, q23, [x0, #160]
+; CHECK-NEXT:    ldr d24, [x0, #256]
 ; CHECK-NEXT:    stp q0, q1, [x8]
 ; CHECK-NEXT:    stp q2, q3, [x8, #32]
 ; CHECK-NEXT:    stp q4, q5, [x8, #64]
-; CHECK-NEXT:    str d20, [x8, #256]
 ; CHECK-NEXT:    stp q6, q7, [x8, #96]
 ; CHECK-NEXT:    stp q16, q17, [x8, #128]
-; CHECK-NEXT:    stp q18, q19, [x8, #160]
-; CHECK-NEXT:    stp q23, q24, [x8, #192]
-; CHECK-NEXT:    stp q21, q22, [x8, #224]
+; CHECK-NEXT:    stp q22, q23, [x8, #160]
+; CHECK-NEXT:    stp q20, q21, [x8, #192]
+; CHECK-NEXT:    stp q18, q19, [x8, #224]
+; CHECK-NEXT:    str d24, [x8, #256]
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-BE-LABEL: test_ldnp_v33f64:
@@ -448,10 +448,11 @@ define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
 define <33 x i8> @test_ldnp_v33i8(<33 x i8>* %A) {
 ; CHECK-LABEL: test_ldnp_v33i8:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q0, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #32]
-; CHECK-NEXT:    stp q1, q0, [x8]
-; CHECK-NEXT:    strb w9, [x8, #32]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    add x9, x8, #32
+; CHECK-NEXT:    ldr b2, [x0, #32]
+; CHECK-NEXT:    stp q0, q1, [x8]
+; CHECK-NEXT:    st1.b { v2 }[0], [x9]
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-BE-LABEL: test_ldnp_v33i8:
@@ -556,15 +557,14 @@ define <4 x i63> @test_ldnp_v4i63(<4 x i63>* %A) {
 define <5 x double> @test_ldnp_v5f64(<5 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v5f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q2, [x0]
+; CHECK-NEXT:    ldnp q0, q2, [x0]
+; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
 ; CHECK-NEXT:    ext.16b v3, v2, v2, #8
-; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ; kill: def $d3 killed $d3 killed $q3
-; CHECK-NEXT:    ; kill: def $d4 killed $d4 killed $q4
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-BE-LABEL: test_ldnp_v5f64:


        


More information about the llvm-commits mailing list