[llvm] [AArch64] Spare N2I roundtrip when splatting float comparison (PR #141806)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 09:58:37 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Guy David (guy-david)
<details>
<summary>Changes</summary>
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that produces a mask, which can later be combined with vectorized DUPs.
For GlobalISel, it seems that an equivalent for `SELECT_CC` does not exist yet?
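Roughly, the pattern being targeted is a splat of a sign-extended scalar float comparison (a minimal sketch adapted from the new dup32 test below; the function name is only for illustration):

```llvm
; Sketch of the targeted pattern: splat of a scalar float-compare mask.
define <4 x i32> @splat_fcmp_mask(float %a, float %b) {
entry:
  %cmp = fcmp ogt float %a, %b        ; scalar float comparison
  %mask = sext i1 %cmp to i32         ; all-ones / all-zeros mask
  %ins = insertelement <4 x i32> poison, i32 %mask, i64 0
  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
  ret <4 x i32> %splat
}
```

Previously this went through a GPR (roughly `fcmp` + `csetm` + `dup Vd, Wn`); with this change the mask is produced and splatted entirely in SIMD registers (`fcmgt` + `dup Vd, Vn.s[0]`), as the updated arm64-neon-v1i1-setcc.ll and the new build-vector-dup-simd.ll tests check.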
---
Full diff: https://github.com/llvm/llvm-project/pull/141806.diff
4 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+89-44)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2-2)
- (modified) llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll (+2-3)
- (added) llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll (+32)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a817ed5f0e917..da5117292e269 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10906,9 +10906,48 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
Cmp.getValue(1));
}
+/// Emit a vector comparison for floating-point values, producing a mask.
+static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
+ AArch64CC::CondCode CC, bool NoNans, EVT VT,
+ const SDLoc &dl, SelectionDAG &DAG) {
+ EVT SrcVT = LHS.getValueType();
+ assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
+ "function only supposed to emit natural comparisons");
+
+ switch (CC) {
+ default:
+ return SDValue();
+ case AArch64CC::NE: {
+ SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
+ return DAG.getNOT(dl, Fcmeq, VT);
+ }
+ case AArch64CC::EQ:
+ return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
+ case AArch64CC::GE:
+ return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
+ case AArch64CC::GT:
+ return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
+ case AArch64CC::LE:
+ if (!NoNans)
+ return SDValue();
+ // If we ignore NaNs then we can use the LS implementation.
+ [[fallthrough]];
+ case AArch64CC::LS:
+ return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
+ case AArch64CC::LT:
+ if (!NoNans)
+ return SDValue();
+ // If we ignore NaNs then we can use the MI implementation.
+ [[fallthrough]];
+ case AArch64CC::MI:
+ return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
+ }
+}
+
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
SDValue RHS, SDValue TVal,
- SDValue FVal, const SDLoc &dl,
+ SDValue FVal, bool HasNoNaNs,
+ const SDLoc &dl,
SelectionDAG &DAG) const {
// Handle f128 first, because it will result in a comparison of some RTLIB
// call result against zero.
@@ -11092,6 +11131,29 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
LHS.getValueType() == MVT::f64);
assert(LHS.getValueType() == RHS.getValueType());
EVT VT = TVal.getValueType();
+
+ // If the purpose of the comparison is to select between all ones
+ // or all zeros, use a vector comparison because the operands are already
+ // stored in SIMD registers.
+ auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
+ auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
+ if (Subtarget->isNeonAvailable() &&
+ (VT.getSizeInBits() == LHS.getValueType().getSizeInBits()) && CTVal &&
+ CFVal &&
+ ((CTVal->isAllOnes() && CFVal->isZero()) ||
+ ((CTVal->isZero()) && CFVal->isAllOnes()))) {
+ AArch64CC::CondCode CC1;
+ AArch64CC::CondCode CC2;
+ bool ShouldInvert = false;
+ changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
+ if (CTVal->isZero() ^ ShouldInvert)
+ std::swap(TVal, FVal);
+ bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
+ SDValue Res = EmitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, dl, DAG);
+ if (Res)
+ return Res;
+ }
+
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11178,8 +11240,9 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
SDValue RHS = Op.getOperand(1);
SDValue TVal = Op.getOperand(2);
SDValue FVal = Op.getOperand(3);
+ bool HasNoNans = Op->getFlags().hasNoNaNs();
SDLoc DL(Op);
- return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+ return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
}
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
@@ -11187,6 +11250,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
SDValue CCVal = Op->getOperand(0);
SDValue TVal = Op->getOperand(1);
SDValue FVal = Op->getOperand(2);
+ bool HasNoNans = Op->getFlags().hasNoNaNs();
SDLoc DL(Op);
EVT Ty = Op.getValueType();
@@ -11253,7 +11317,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
DAG.getUNDEF(MVT::f32), FVal);
}
- SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+ SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15506,47 +15570,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
llvm_unreachable("unexpected shift opcode");
}
-static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
- AArch64CC::CondCode CC, bool NoNans, EVT VT,
- const SDLoc &dl, SelectionDAG &DAG) {
- EVT SrcVT = LHS.getValueType();
- assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
- "function only supposed to emit natural comparisons");
-
- if (SrcVT.getVectorElementType().isFloatingPoint()) {
- switch (CC) {
- default:
- return SDValue();
- case AArch64CC::NE: {
- SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
- return DAG.getNOT(dl, Fcmeq, VT);
- }
- case AArch64CC::EQ:
- return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
- case AArch64CC::GE:
- return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
- case AArch64CC::GT:
- return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
- case AArch64CC::LE:
- if (!NoNans)
- return SDValue();
- // If we ignore NaNs then we can use to the LS implementation.
- [[fallthrough]];
- case AArch64CC::LS:
- return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
- case AArch64CC::LT:
- if (!NoNans)
- return SDValue();
- // If we ignore NaNs then we can use to the MI implementation.
- [[fallthrough]];
- case AArch64CC::MI:
- return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
- }
- }
-
- return SDValue();
-}
-
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
SelectionDAG &DAG) const {
if (Op.getValueType().isScalableVector())
@@ -25365,6 +25388,28 @@ static SDValue performDUPCombine(SDNode *N,
}
if (N->getOpcode() == AArch64ISD::DUP) {
+ // If the instruction is known to produce a scalar in SIMD registers, we can
+ // duplicate it across the vector lanes using DUPLANE instead of moving
+ // it to a GPR first. For example, this allows us to handle:
+ // v4i32 = DUP (i32 (FCMGT (f32, f32)))
+ SDValue Op = N->getOperand(0);
+ // FIXME: Ideally, we should be able to handle all instructions that
+ // produce a scalar value in FPRs.
+ if (Op.getOpcode() == AArch64ISD::FCMEQ ||
+ Op.getOpcode() == AArch64ISD::FCMGE ||
+ Op.getOpcode() == AArch64ISD::FCMGT) {
+ EVT ElemVT = VT.getVectorElementType();
+ EVT ExpandedVT = VT;
+ // Insert into a 128-bit vector to match DUPLANE's pattern.
+ if (VT.getSizeInBits() != 128)
+ ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
+ 128 / ElemVT.getSizeInBits());
+ SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
+ SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
+ DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
+ return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
+ }
+
if (DCI.isAfterLegalizeDAG()) {
// If scalar dup's operand is extract_vector_elt, try to combine them into
// duplane. For example,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1924d20f67f49..e2e2150133e82 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -645,8 +645,8 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
- SDValue TVal, SDValue FVal, const SDLoc &dl,
- SelectionDAG &DAG) const;
+ SDValue TVal, SDValue FVal, bool HasNoNans,
+ const SDLoc &dl, SelectionDAG &DAG) const;
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll b/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
index 6c70d19a977a5..05178c1dc291c 100644
--- a/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
@@ -174,9 +174,8 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
; CHECK-LABEL: test_select_f16_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: fcvt s0, h0
-; CHECK-NEXT: fcmp s0, s0
-; CHECK-NEXT: csetm w8, vs
-; CHECK-NEXT: dup v0.4h, w8
+; CHECK-NEXT: fcmgt s0, s0, s0
+; CHECK-NEXT: dup v0.4h, v0.h[0]
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
; CHECK-NEXT: ret
%i179 = fcmp uno half %i105, zeroinitializer
diff --git a/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll b/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll
new file mode 100644
index 0000000000000..c52b8817ab6f8
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
+
+define <4 x float> @dup32(float %a, float %b) {
+; CHECK-LABEL: dup32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fcmgt s0, s0, s1
+; CHECK-NEXT: dup v0.4s, v0.s[0]
+; CHECK-NEXT: ret
+entry:
+ %0 = fcmp ogt float %a, %b
+ %vcmpd.i = sext i1 %0 to i32
+ %vecinit.i = insertelement <4 x i32> poison, i32 %vcmpd.i, i64 0
+ %1 = bitcast <4 x i32> %vecinit.i to <4 x float>
+ %2 = shufflevector <4 x float> %1, <4 x float> poison, <4 x i32> zeroinitializer
+ ret <4 x float> %2
+}
+
+define <2 x double> @dup64(double %a, double %b) {
+; CHECK-LABEL: dup64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fcmgt d0, d0, d1
+; CHECK-NEXT: dup v0.2d, v0.d[0]
+; CHECK-NEXT: ret
+entry:
+ %0 = fcmp ogt double %a, %b
+ %vcmpd.i = sext i1 %0 to i64
+ %vecinit.i = insertelement <2 x i64> poison, i64 %vcmpd.i, i64 0
+ %1 = bitcast <2 x i64> %vecinit.i to <2 x double>
+ %2 = shufflevector <2 x double> %1, <2 x double> poison, <2 x i32> zeroinitializer
+ ret <2 x double> %2
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/141806
More information about the llvm-commits mailing list