[llvm] [AArch64] Avoid streaming mode hazards from FP constant stores (PR #114207)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 04:12:03 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
Currently, several places replace stores of floating-point constants with equivalent integer stores (and integer constant generation). This is generally beneficial. However, in streaming functions on some SME devices, this could result in a streaming-mode memory hazard.
This patch adds a TLI hook, `canUseIntLoadStoreForFloatValues()`, and implements it for AArch64 to prevent these transformations in streaming and streaming-compatible functions when a hazard is possible.
---
Full diff: https://github.com/llvm/llvm-project/pull/114207.diff
6 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+3)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+40-35)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+17-2)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
- (added) llvm/test/CodeGen/AArch64/sve-streaming-mode-fp-constant-stores.ll (+171)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 8e0cdc6f1a5e77..c8f4436111ba78 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -489,6 +489,10 @@ class TargetLoweringBase {
return true;
}
+ /// Returns true if a floating-point load or store can be replaced with an
+ /// equivalent integer load or store without negatively affecting performance.
+ virtual bool canUseIntLoadStoreForFloatValues() const { return true; }
+
/// Return true if it is profitable to convert a select of FP constants into
/// a constant pool load whose address depends on the select condition. The
/// parameter may be used to differentiate a select with FP compare from
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b800204d917503..ec3a7bc5eae0ea 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -21587,6 +21587,9 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
// processor operation but an i64 (which is not legal) requires two. So the
// transform should not be done in this case.
+ if (!TLI.canUseIntLoadStoreForFloatValues())
+ return SDValue();
+
SDValue Tmp;
switch (CFP->getSimpleValueType(0).SimpleTy) {
default:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 6ba12cfb8c5148..60aa5f1bd2fe01 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -430,45 +430,50 @@ SDValue SelectionDAGLegalize::OptimizeFloatStore(StoreSDNode* ST) {
if (Value.getOpcode() == ISD::TargetConstantFP)
return SDValue();
- if (ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(Value)) {
- if (CFP->getValueType(0) == MVT::f32 &&
- TLI.isTypeLegal(MVT::i32)) {
- SDValue Con = DAG.getConstant(CFP->getValueAPF().
- bitcastToAPInt().zextOrTrunc(32),
- SDLoc(CFP), MVT::i32);
- return DAG.getStore(Chain, dl, Con, Ptr, ST->getPointerInfo(),
- ST->getOriginalAlign(), MMOFlags, AAInfo);
- }
+ ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(Value);
+ if (!CFP)
+ return SDValue();
- if (CFP->getValueType(0) == MVT::f64 &&
- !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
- // If this target supports 64-bit registers, do a single 64-bit store.
- if (TLI.isTypeLegal(MVT::i64)) {
- SDValue Con = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
- zextOrTrunc(64), SDLoc(CFP), MVT::i64);
- return DAG.getStore(Chain, dl, Con, Ptr, ST->getPointerInfo(),
- ST->getOriginalAlign(), MMOFlags, AAInfo);
- }
+ if (!TLI.canUseIntLoadStoreForFloatValues())
+ return SDValue();
- if (TLI.isTypeLegal(MVT::i32) && !ST->isVolatile()) {
- // Otherwise, if the target supports 32-bit registers, use 2 32-bit
- // stores. If the target supports neither 32- nor 64-bits, this
- // xform is certainly not worth it.
- const APInt &IntVal = CFP->getValueAPF().bitcastToAPInt();
- SDValue Lo = DAG.getConstant(IntVal.trunc(32), dl, MVT::i32);
- SDValue Hi = DAG.getConstant(IntVal.lshr(32).trunc(32), dl, MVT::i32);
- if (DAG.getDataLayout().isBigEndian())
- std::swap(Lo, Hi);
-
- Lo = DAG.getStore(Chain, dl, Lo, Ptr, ST->getPointerInfo(),
- ST->getOriginalAlign(), MMOFlags, AAInfo);
- Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), dl);
- Hi = DAG.getStore(Chain, dl, Hi, Ptr,
- ST->getPointerInfo().getWithOffset(4),
+ EVT FloatTy = CFP->getValueType(0);
+ if (FloatTy == MVT::f32 && TLI.isTypeLegal(MVT::i32)) {
+ SDValue Con =
+ DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(32),
+ SDLoc(CFP), MVT::i32);
+ return DAG.getStore(Chain, dl, Con, Ptr, ST->getPointerInfo(),
+ ST->getOriginalAlign(), MMOFlags, AAInfo);
+ }
+
+ if (FloatTy == MVT::f64 && !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
+ // If this target supports 64-bit registers, do a single 64-bit store.
+ if (TLI.isTypeLegal(MVT::i64)) {
+ SDValue Con =
+ DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(64),
+ SDLoc(CFP), MVT::i64);
+ return DAG.getStore(Chain, dl, Con, Ptr, ST->getPointerInfo(),
ST->getOriginalAlign(), MMOFlags, AAInfo);
+ }
- return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo, Hi);
- }
+ if (TLI.isTypeLegal(MVT::i32) && !ST->isVolatile()) {
+ // Otherwise, if the target supports 32-bit registers, use 2 32-bit
+ // stores. If the target supports neither 32- nor 64-bits, this xform is
+ // certainly not worth it.
+ const APInt &IntVal = CFP->getValueAPF().bitcastToAPInt();
+ SDValue Lo = DAG.getConstant(IntVal.trunc(32), dl, MVT::i32);
+ SDValue Hi = DAG.getConstant(IntVal.lshr(32).trunc(32), dl, MVT::i32);
+ if (DAG.getDataLayout().isBigEndian())
+ std::swap(Lo, Hi);
+
+ Lo = DAG.getStore(Chain, dl, Lo, Ptr, ST->getPointerInfo(),
+ ST->getOriginalAlign(), MMOFlags, AAInfo);
+ Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), dl);
+ Hi = DAG.getStore(Chain, dl, Hi, Ptr,
+ ST->getPointerInfo().getWithOffset(4),
+ ST->getOriginalAlign(), MMOFlags, AAInfo);
+
+ return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo, Hi);
}
}
return SDValue();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 31a720ed7b5c77..bdabfafd43976f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22514,7 +22514,8 @@ static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
/// movi v0.2d, #0
/// str q0, [x0]
///
-static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
+static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St,
+ AArch64Subtarget const &Subtarget) {
SDValue StVal = St.getValue();
EVT VT = StVal.getValueType();
@@ -22522,6 +22523,13 @@ static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
if (VT.isScalableVector())
return SDValue();
+ // Do not replace the FP store when it could result in a streaming memory
+ // hazard.
+ if (VT.getVectorElementType().isFloatingPoint() &&
+ Subtarget.getStreamingHazardSize() > 0 &&
+ (Subtarget.isStreaming() || Subtarget.isStreamingCompatible()))
+ return SDValue();
+
// It is beneficial to scalarize a zero splat store for 2 or 3 i64 elements or
// 2, 3 or 4 i32 elements.
int NumVecElts = VT.getVectorNumElements();
@@ -22651,7 +22659,7 @@ static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
// If we get a splat of zeros, convert this vector store to a store of
// scalars. They will be merged into store pairs of xzr thereby removing one
// instruction and one register.
- if (SDValue ReplacedZeroSplat = replaceZeroVectorStore(DAG, *S))
+ if (SDValue ReplacedZeroSplat = replaceZeroVectorStore(DAG, *S, *Subtarget))
return ReplacedZeroSplat;
// FIXME: The logic for deciding if an unaligned store should be split should
@@ -27706,6 +27714,13 @@ bool AArch64TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
return !NoFloat || MemVT.getSizeInBits() <= 64;
}
+bool AArch64TargetLowering::canUseIntLoadStoreForFloatValues() const {
+ // Avoid replacing FP loads/stores with integer ones when it could result in a
+ // streaming memory hazard.
+ return !Subtarget->getStreamingHazardSize() ||
+ (!Subtarget->isStreaming() && !Subtarget->isStreamingCompatible());
+}
+
bool AArch64TargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
// We want inc-of-add for scalars and sub-of-not for vectors.
return VT.isScalarInteger();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d696355bb062a8..b77bb29144713b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -840,6 +840,8 @@ class AArch64TargetLowering : public TargetLowering {
bool canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
const MachineFunction &MF) const override;
+ bool canUseIntLoadStoreForFloatValues() const override;
+
bool isCheapToSpeculateCttz(Type *) const override {
return true;
}
diff --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fp-constant-stores.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fp-constant-stores.ll
new file mode 100644
index 00000000000000..c060bc0a901ea4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fp-constant-stores.ll
@@ -0,0 +1,171 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -aarch64-streaming-hazard-size=64 -force-streaming-compatible -mattr=+sve < %s | FileCheck %s
+; RUN: llc -aarch64-streaming-hazard-size=64 -force-streaming -mattr=+sme < %s | FileCheck %s
+; RUN: llc -force-streaming -mattr=+sme < %s | FileCheck %s --check-prefix=NOHAZARD
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; This test checks that in streaming[-compatible] functions if there could be
+; a hazard between GPR and FPR memory operations, then integer stores are not
+; used for floating-point constants.
+
+define void @"store_f64_0.0"(ptr %num) {
+; CHECK-LABEL: store_f64_0.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fmov d0, xzr
+; CHECK-NEXT: str d0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f64_0.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: str xzr, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store double 0.000000e+00, ptr %num, align 8
+ ret void
+}
+
+define void @"store_f64_1.0"(ptr %num) {
+; CHECK-LABEL: store_f64_1.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fmov d0, #1.00000000
+; CHECK-NEXT: str d0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f64_1.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: mov x8, #4607182418800017408 // =0x3ff0000000000000
+; NOHAZARD-NEXT: str x8, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store double 1.000000e+00, ptr %num, align 8
+ ret void
+}
+
+define void @"store_f64_1.23456789"(ptr %num) {
+; CHECK-LABEL: store_f64_1.23456789:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: adrp x8, .LCPI2_0
+; CHECK-NEXT: ldr d0, [x8, :lo12:.LCPI2_0]
+; CHECK-NEXT: str d0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f64_1.23456789:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: mov x8, #56859 // =0xde1b
+; NOHAZARD-NEXT: movk x8, #17027, lsl #16
+; NOHAZARD-NEXT: movk x8, #49354, lsl #32
+; NOHAZARD-NEXT: movk x8, #16371, lsl #48
+; NOHAZARD-NEXT: str x8, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store double 0x3FF3C0CA4283DE1B, ptr %num, align 8
+ ret void
+}
+
+define void @"store_f32_0.0"(ptr %num) {
+; CHECK-LABEL: store_f32_0.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fmov s0, wzr
+; CHECK-NEXT: str s0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f32_0.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: str wzr, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store float 0.000000e+00, ptr %num, align 4
+ ret void
+}
+
+define void @"store_f32_1.0"(ptr %num) {
+; CHECK-LABEL: store_f32_1.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fmov s0, #1.00000000
+; CHECK-NEXT: str s0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f32_1.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: mov w8, #1065353216 // =0x3f800000
+; NOHAZARD-NEXT: str w8, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store float 1.000000e+00, ptr %num, align 4
+ ret void
+}
+
+define void @"store_f32_1.23456789"(ptr %num) {
+; CHECK-LABEL: store_f32_1.23456789:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov w8, #1618 // =0x652
+; CHECK-NEXT: movk w8, #16286, lsl #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: str s0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_f32_1.23456789:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: mov w8, #1618 // =0x652
+; NOHAZARD-NEXT: movk w8, #16286, lsl #16
+; NOHAZARD-NEXT: str w8, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store float 0x3FF3C0CA40000000, ptr %num, align 4
+ ret void
+}
+
+define void @"store_v4f32_0.0"(ptr %num) {
+; CHECK-LABEL: store_v4f32_0.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z0.s, #0 // =0x0
+; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_v4f32_0.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: stp xzr, xzr, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store <4 x float> zeroinitializer, ptr %num, align 16
+ ret void
+}
+
+define void @"store_v4f32_1.0"(ptr %num) {
+; CHECK-LABEL: store_v4f32_1.0:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fmov z0.s, #1.00000000
+; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_v4f32_1.0:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: fmov z0.s, #1.00000000
+; NOHAZARD-NEXT: str q0, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, ptr %num, align 16
+ ret void
+}
+
+define void @"store_v4f32_1.23456789"(ptr %num) {
+; CHECK-LABEL: store_v4f32_1.23456789:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov w8, #1618 // =0x652
+; CHECK-NEXT: movk w8, #16286, lsl #16
+; CHECK-NEXT: mov z0.s, w8
+; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: ret
+;
+; NOHAZARD-LABEL: store_v4f32_1.23456789:
+; NOHAZARD: // %bb.0: // %entry
+; NOHAZARD-NEXT: mov w8, #1618 // =0x652
+; NOHAZARD-NEXT: movk w8, #16286, lsl #16
+; NOHAZARD-NEXT: mov z0.s, w8
+; NOHAZARD-NEXT: str q0, [x0]
+; NOHAZARD-NEXT: ret
+entry:
+ store <4 x float> <float 0x3FF3C0CA40000000, float 0x3FF3C0CA40000000, float 0x3FF3C0CA40000000, float 0x3FF3C0CA40000000>, ptr %num, align 16
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/114207
More information about the llvm-commits
mailing list