[llvm] 2d3c260 - [AArch64] Break non-temporal loads over 256 bits into 256-bit loads and a smaller load
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 28 07:21:13 PDT 2022
Author: Florian Hahn
Date: 2022-09-28T15:20:26+01:00
New Revision: 2d3c260362a29404909dadd65e904e872b52e09c
URL: https://github.com/llvm/llvm-project/commit/2d3c260362a29404909dadd65e904e872b52e09c
DIFF: https://github.com/llvm/llvm-project/commit/2d3c260362a29404909dadd65e904e872b52e09c.diff
LOG: [AArch64] Break non-temporal loads over 256 bits into 256-bit loads and a smaller load
Currently, non-temporal loads wider than 256 bits are broken up inefficiently.
For example, `v17i32` gets broken into a series of 128-bit loads. It is better
to use 256-bit non-temporal loads instead.
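As a minimal illustration (an IR sketch mirroring the updated tests; the
metadata id is arbitrary), the loads this patch targets look like:

define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
  %lv = load <17 x float>, <17 x float>* %A, align 4, !nontemporal !0
  ret <17 x float> %lv
}

!0 = !{i32 1}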
Reviewed By: fhahn
Differential Revision: https://reviews.llvm.org/D133421
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/nontemporal-load.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28751afd70bd..f7cbc61a9ab2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -899,8 +899,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
- if (Subtarget->supportsAddressTopByteIgnored())
- setTargetDAGCombine(ISD::LOAD);
+ setTargetDAGCombine(ISD::LOAD);
setTargetDAGCombine(ISD::MSTORE);
@@ -18080,6 +18079,87 @@ static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) {
return SDValue();
}
+// Perform TBI simplification if supported by the target and try to break up
+// non-temporal loads larger than 256 bits for odd types, so that LDNP
+// Q-register (256-bit) load instructions can be selected.
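+// For example, a v17f32 (544-bit) non-temporal load is split into two 256-bit
+// v8f32 loads plus a v1f32 load for the remaining element; the pieces are then
+// concatenated and the original v17f32 value is extracted from the result.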
+static SDValue performLOADCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ if (Subtarget->supportsAddressTopByteIgnored())
+ performTBISimplification(N->getOperand(1), DCI, DAG);
+
+ LoadSDNode *LD = cast<LoadSDNode>(N);
+ EVT MemVT = LD->getMemoryVT();
+ if (LD->isVolatile() || !LD->isNonTemporal() || !Subtarget->isLittleEndian())
+ return SDValue(N, 0);
+
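+ // Only split fixed-width vector loads that are wider than 256 bits, do not
+ // already divide evenly into 256-bit chunks, and whose element size divides
+ // 256 evenly.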
+ if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
+ MemVT.getSizeInBits() % 256 == 0 ||
+ 256 % MemVT.getScalarSizeInBits() != 0)
+ return SDValue(N, 0);
+
+ SDLoc DL(LD);
+ SDValue Chain = LD->getChain();
+ SDValue BasePtr = LD->getBasePtr();
+ SDNodeFlags Flags = LD->getFlags();
+ SmallVector<SDValue, 4> LoadOps;
+ SmallVector<SDValue, 4> LoadOpsChain;
+ // Replace a non-temporal load over 256 bits with a series of 256-bit loads
+ // and a scalar/vector load for the part smaller than 256 bits. This way we
+ // can use 256-bit loads and reduce the number of load instructions generated.
+ MVT NewVT =
+ MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+ 256 / MemVT.getVectorElementType().getSizeInBits());
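+ // E.g. for v17f32 the element type is f32, so NewVT is v8f32
+ // (256 / 32 == 8 elements).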
+ unsigned Num256Loads = MemVT.getSizeInBits() / 256;
+ // Create the 256-bit loads at byte offsets 0, 32, ..., (Num256Loads - 1) * 32.
+ for (unsigned I = 0; I < Num256Loads; I++) {
+ unsigned PtrOffset = I * 32;
+ SDValue NewPtr = DAG.getMemBasePlusOffset(
+ BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+ Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+ SDValue NewLoad = DAG.getLoad(
+ NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
+ NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
+ LoadOps.push_back(NewLoad);
+ LoadOpsChain.push_back(SDValue(cast<SDNode>(NewLoad), 1));
+ }
+
+ // Process the remaining bits of the load operation.
+ // This is done by creating an UNDEF vector to match the size of the
+ // 256-bit loads and inserting the remaining load into it. The original
+ // load type is extracted at the end using an EXTRACT_SUBVECTOR node.
+ unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
+ unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
+ MVT RemainingVT = MVT::getVectorVT(
+ MemVT.getVectorElementType().getSimpleVT(),
+ BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
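+ // E.g. for v17f32, 544 % 256 == 32 bits remain, so RemainingVT is v1f32 and
+ // the remainder is loaded from byte offset 64.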
+ SDValue NewPtr =
+ DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+ Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+ SDValue RemainingLoad =
+ DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
+ LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+ LD->getMemOperand()->getFlags(), LD->getAAInfo());
+ SDValue UndefVector = DAG.getUNDEF(NewVT);
+ SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
+ SDValue ExtendedRemainingLoad =
+ DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+ {UndefVector, RemainingLoad, InsertIdx});
+ LoadOps.push_back(ExtendedRemainingLoad);
+ LoadOpsChain.push_back(SDValue(cast<SDNode>(RemainingLoad), 1));
+ EVT ConcatVT =
+ EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+ LoadOps.size() * NewVT.getVectorNumElements());
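+ // E.g. for v17f32 this concatenates three v8f32 values into a v24f32, from
+ // which the original v17f32 is extracted below.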
+ SDValue ConcatVectors =
+ DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
+ // Extract the original vector type size.
+ SDValue ExtractSubVector =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+ {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+ SDValue TokenFactor =
+ DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
+ return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
+}
+
static SDValue performSTORECombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG,
@@ -20129,9 +20209,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SETCC:
return performSETCCCombine(N, DCI, DAG);
case ISD::LOAD:
- if (performTBISimplification(N->getOperand(1), DCI, DAG))
- return SDValue(N, 0);
- break;
+ return performLOADCombine(N, DCI, DAG, Subtarget);
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
case ISD::MSTORE:
diff --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
index 47556321aea4..288ba22e7928 100644
--- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll
+++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
@@ -320,12 +320,12 @@ define <16 x float> @test_ldnp_v16f32(<16 x float>* %A) {
define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
; CHECK-LABEL: test_ldnp_v17f32:
; CHECK: ; %bb.0:
-; CHECK-NEXT: ldp q1, q2, [x0, #32]
-; CHECK-NEXT: ldp q3, q4, [x0]
-; CHECK-NEXT: ldr s0, [x0, #64]
-; CHECK-NEXT: stp q3, q4, [x8]
-; CHECK-NEXT: stp q1, q2, [x8, #32]
-; CHECK-NEXT: str s0, [x8, #64]
+; CHECK-NEXT: ldnp q0, q1, [x0, #32]
+; CHECK-NEXT: ldnp q2, q3, [x0]
+; CHECK-NEXT: ldr s4, [x0, #64]
+; CHECK-NEXT: stp q0, q1, [x8, #32]
+; CHECK-NEXT: stp q2, q3, [x8]
+; CHECK-NEXT: str s4, [x8, #64]
; CHECK-NEXT: ret
;
; CHECK-BE-LABEL: test_ldnp_v17f32:
@@ -354,24 +354,24 @@ define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
; CHECK-LABEL: test_ldnp_v33f64:
; CHECK: ; %bb.0:
-; CHECK-NEXT: ldp q0, q1, [x0]
-; CHECK-NEXT: ldp q2, q3, [x0, #32]
-; CHECK-NEXT: ldp q4, q5, [x0, #64]
-; CHECK-NEXT: ldp q6, q7, [x0, #96]
-; CHECK-NEXT: ldp q16, q17, [x0, #128]
-; CHECK-NEXT: ldp q18, q19, [x0, #160]
-; CHECK-NEXT: ldp q21, q22, [x0, #224]
-; CHECK-NEXT: ldp q23, q24, [x0, #192]
-; CHECK-NEXT: ldr d20, [x0, #256]
+; CHECK-NEXT: ldnp q0, q1, [x0]
+; CHECK-NEXT: ldnp q2, q3, [x0, #32]
+; CHECK-NEXT: ldnp q4, q5, [x0, #64]
+; CHECK-NEXT: ldnp q6, q7, [x0, #96]
+; CHECK-NEXT: ldnp q16, q17, [x0, #128]
+; CHECK-NEXT: ldnp q18, q19, [x0, #224]
+; CHECK-NEXT: ldnp q20, q21, [x0, #192]
+; CHECK-NEXT: ldnp q22, q23, [x0, #160]
+; CHECK-NEXT: ldr d24, [x0, #256]
; CHECK-NEXT: stp q0, q1, [x8]
; CHECK-NEXT: stp q2, q3, [x8, #32]
; CHECK-NEXT: stp q4, q5, [x8, #64]
-; CHECK-NEXT: str d20, [x8, #256]
; CHECK-NEXT: stp q6, q7, [x8, #96]
; CHECK-NEXT: stp q16, q17, [x8, #128]
-; CHECK-NEXT: stp q18, q19, [x8, #160]
-; CHECK-NEXT: stp q23, q24, [x8, #192]
-; CHECK-NEXT: stp q21, q22, [x8, #224]
+; CHECK-NEXT: stp q22, q23, [x8, #160]
+; CHECK-NEXT: stp q20, q21, [x8, #192]
+; CHECK-NEXT: stp q18, q19, [x8, #224]
+; CHECK-NEXT: str d24, [x8, #256]
; CHECK-NEXT: ret
;
; CHECK-BE-LABEL: test_ldnp_v33f64:
@@ -448,10 +448,11 @@ define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
define <33 x i8> @test_ldnp_v33i8(<33 x i8>* %A) {
; CHECK-LABEL: test_ldnp_v33i8:
; CHECK: ; %bb.0:
-; CHECK-NEXT: ldp q1, q0, [x0]
-; CHECK-NEXT: ldrb w9, [x0, #32]
-; CHECK-NEXT: stp q1, q0, [x8]
-; CHECK-NEXT: strb w9, [x8, #32]
+; CHECK-NEXT: ldnp q0, q1, [x0]
+; CHECK-NEXT: add x9, x8, #32
+; CHECK-NEXT: ldr b2, [x0, #32]
+; CHECK-NEXT: stp q0, q1, [x8]
+; CHECK-NEXT: st1.b { v2 }[0], [x9]
; CHECK-NEXT: ret
;
; CHECK-BE-LABEL: test_ldnp_v33i8:
@@ -556,15 +557,14 @@ define <4 x i63> @test_ldnp_v4i63(<4 x i63>* %A) {
define <5 x double> @test_ldnp_v5f64(<5 x double>* %A) {
; CHECK-LABEL: test_ldnp_v5f64:
; CHECK: ; %bb.0:
-; CHECK-NEXT: ldp q0, q2, [x0]
+; CHECK-NEXT: ldnp q0, q2, [x0]
+; CHECK-NEXT: ldr d4, [x0, #32]
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
; CHECK-NEXT: ext.16b v3, v2, v2, #8
-; CHECK-NEXT: ldr d4, [x0, #32]
; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ; kill: def $d3 killed $d3 killed $q3
-; CHECK-NEXT: ; kill: def $d4 killed $d4 killed $q4
; CHECK-NEXT: ret
;
; CHECK-BE-LABEL: test_ldnp_v5f64: