[llvm] d916856 - [AArch64] Allow strict opcodes in faddp patterns
John Brawn via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 17 05:12:34 PST 2022
Author: John Brawn
Date: 2022-02-17T13:11:55Z
New Revision: d916856bee1165aa78ca342cdd43523c33333736
URL: https://github.com/llvm/llvm-project/commit/d916856bee1165aa78ca342cdd43523c33333736
DIFF: https://github.com/llvm/llvm-project/commit/d916856bee1165aa78ca342cdd43523c33333736.diff
LOG: [AArch64] Allow strict opcodes in faddp patterns
This also requires an adjustment to the code in AArch64ISelLowering so that
vector_extract is distributed over strict_fadd.
Differential Revision: https://reviews.llvm.org/D118489
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/test/CodeGen/AArch64/faddp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d06fd2b27341..6e763202ce91 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14228,6 +14228,7 @@ static SDValue performANDCombine(SDNode *N,
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
switch (Opcode) {
+ case ISD::STRICT_FADD:
case ISD::FADD:
return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
case ISD::ADD:
@@ -14244,6 +14245,7 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
const bool FullFP16 =
static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
+ bool IsStrict = N0->isStrictFPOpcode();
// Rewrite for pairwise fadd pattern
// (f32 (extract_vector_elt
@@ -14252,11 +14254,14 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
// ->
// (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
// (extract_vector_elt (vXf32 Other) 1))
+ // For strict_fadd we need to make sure the old strict_fadd can be deleted, so
+ // we can only do this when it's used only by the extract_vector_elt.
if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
- hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
+ hasPairwiseAdd(N0->getOpcode(), VT, FullFP16) &&
+ (!IsStrict || N0.hasOneUse())) {
SDLoc DL(N0);
- SDValue N00 = N0->getOperand(0);
- SDValue N01 = N0->getOperand(1);
+ SDValue N00 = N0->getOperand(IsStrict ? 1 : 0);
+ SDValue N01 = N0->getOperand(IsStrict ? 2 : 1);
ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
SDValue Other = N00;
@@ -14269,11 +14274,23 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
Other == Shuffle->getOperand(0)) {
- return DAG.getNode(N0->getOpcode(), DL, VT,
- DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
- DAG.getConstant(0, DL, MVT::i64)),
- DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
- DAG.getConstant(1, DL, MVT::i64)));
+ SDValue Extract1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+ DAG.getConstant(0, DL, MVT::i64));
+ SDValue Extract2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+ DAG.getConstant(1, DL, MVT::i64));
+ if (!IsStrict)
+ return DAG.getNode(N0->getOpcode(), DL, VT, Extract1, Extract2);
+
+ // For strict_fadd we need uses of the final extract_vector to be replaced
+ // with the strict_fadd, but we also need uses of the chain output of the
+ // original strict_fadd to use the chain output of the new strict_fadd as
+ // otherwise it may not be deleted.
+ SDValue Ret = DAG.getNode(N0->getOpcode(), DL,
+ {VT, MVT::Other},
+ {N0->getOperand(0), Extract1, Extract2});
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Ret);
+ DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Ret.getValue(1));
+ return SDValue(N, 0);
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 53a06c2b9e8e..664f670d741c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -8100,17 +8100,17 @@ defm : InsertSubvectorUndef<i64>;
def : Pat<(i64 (add (vector_extract (v2i64 FPR128:$Rn), (i64 0)),
(vector_extract (v2i64 FPR128:$Rn), (i64 1)))),
(i64 (ADDPv2i64p (v2i64 FPR128:$Rn)))>;
-def : Pat<(f64 (fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
- (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
+def : Pat<(f64 (any_fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
+ (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
(f64 (FADDPv2i64p (v2f64 FPR128:$Rn)))>;
// vector_extract on 64-bit vectors gets promoted to a 128 bit vector,
// so we match on v4f32 here, not v2f32. This will also catch adding
// the low two lanes of a true v4f32 vector.
-def : Pat<(fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
- (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
+ (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
(f32 (FADDPv2i32p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
-def : Pat<(fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
- (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
+ (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
(f16 (FADDPv2i16p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
// Scalar 64-bit shifts in FPR64 registers.
diff --git a/llvm/test/CodeGen/AArch64/faddp.ll b/llvm/test/CodeGen/AArch64/faddp.ll
index 06e976136c37..1476f7bcda5e 100644
--- a/llvm/test/CodeGen/AArch64/faddp.ll
+++ b/llvm/test/CodeGen/AArch64/faddp.ll
@@ -100,3 +100,83 @@ entry:
%1 = extractelement <2 x i64> %0, i32 0
ret i64 %1
}
+
+define float @faddp_2xfloat_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+ %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <2 x float> %0, i32 0
+ ret float %1
+}
+
+define float @faddp_4xfloat_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+ %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %a, <4 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <4 x float> %0, i32 0
+ ret float %1
+}
+
+define float @faddp_4xfloat_commute_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_commute_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+ %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %shift, <4 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <4 x float> %0, i32 0
+ ret float %1
+}
+
+define float @faddp_2xfloat_commute_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_commute_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+ %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %shift, <2 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <2 x float> %0, i32 0
+ ret float %1
+}
+
+define double @faddp_2xdouble_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: faddp d0, v0.2d
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+ %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %a, <2 x double> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <2 x double> %0, i32 0
+ ret double %1
+}
+
+define double @faddp_2xdouble_commute_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_commute_strict:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: faddp d0, v0.2d
+; CHECK-NEXT: ret
+entry:
+ %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+ %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %shift, <2 x double> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+ %1 = extractelement <2 x double> %0, i32 0
+ ret double %1
+}
+
+attributes #0 = { strictfp }
+
+declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata)
+declare <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata)
+declare <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double>, <2 x double>, metadata, metadata)
More information about the llvm-commits
mailing list