[llvm] [AArch64] Fold sext-in-reg for predicate -> fixed-length conversions. (PR #176883)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 20 01:57:53 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Sander de Smalen (sdesmalen-arm)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/176883.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+32-7)
- (modified) llvm/test/CodeGen/AArch64/alias_mask.ll (+35-59)
- (added) llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll (+19)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c81f457898d7c..0a4216cbdda0d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24690,6 +24690,19 @@ static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
+// Returns true when V is the following pattern:
+// v16i8 extract_subvector(
+// nxv16i8 sign_extend (nxv16i1 v))
+static bool isPredicateToFixedLengthVectorConversion(SDValue V) {
+ if (V.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+ return false;
+
+ SDValue Src = V.getOperand(0);
+ return V.getValueType() == MVT::v16i8 && Src.getValueType() == MVT::nxv16i8 &&
+ Src.getOpcode() == ISD::SIGN_EXTEND &&
+ Src.getOperand(0).getValueType() == MVT::nxv16i1;
+}
+
/// Optimize a vector shift instruction and its operand if shifted out
/// bits are not used.
static SDValue performVectorShiftCombine(SDNode *N,
@@ -24707,9 +24720,12 @@ static SDValue performVectorShiftCombine(SDNode *N,
// Remove sign_extend_inreg (ashr(shl(x)) based on the number of sign bits.
if (N->getOpcode() == AArch64ISD::VASHR &&
Op.getOpcode() == AArch64ISD::VSHL &&
- N->getOperand(1) == Op.getOperand(1))
- if (DCI.DAG.ComputeNumSignBits(Op.getOperand(0)) > ShiftImm)
- return Op.getOperand(0);
+ N->getOperand(1) == Op.getOperand(1)) {
+ SDValue ShiftSrc = Op.getOperand(0);
+ if (isPredicateToFixedLengthVectorConversion(ShiftSrc) ||
+ DCI.DAG.ComputeNumSignBits(ShiftSrc) > ShiftImm)
+ return ShiftSrc;
+ }
// If the shift is exact, the shifted out bits matter.
if (N->getFlags().hasExact())
@@ -26951,9 +26967,13 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
}
// setcc_merge_zero(
- // pred, insert_subvector(undef, signext_inreg(vNi1), 0), != splat(0))
+ // pred, insert_subvector(undef, signext_inreg(x), 0), != splat(0))
+ // => setcc_merge_zero(
+ // pred, insert_subvector(undef, shl(x), 0), != splat(0))
+ // or:
// => setcc_merge_zero(
- // pred, insert_subvector(undef, shl(vNi1), 0), != splat(0))
+ // pred, insert_subvector(undef, x, 0), != splat(0))
+ // iff it can be proven that x is already sign-extended.
if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
LHS->getOpcode() == ISD::INSERT_SUBVECTOR && LHS.hasOneUse()) {
SDValue L0 = LHS->getOperand(0);
@@ -26962,9 +26982,14 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
if (L0.isUndef() && isNullConstant(L2) && isSignExtInReg(L1)) {
SDLoc DL(N);
- SDValue Shl = L1.getOperand(0);
+ SDValue ExtVal = L1.getOperand(0);
+ unsigned NumShiftBits = ExtVal.getConstantOperandVal(1);
+ SDValue ShlSrc = ExtVal.getOperand(0);
+ if (isPredicateToFixedLengthVectorConversion(ShlSrc) ||
+ DCI.DAG.ComputeNumSignBits(ShlSrc) > NumShiftBits)
+ ExtVal = ShlSrc;
SDValue NewLHS = DAG.getNode(ISD::INSERT_SUBVECTOR, DL,
- LHS.getValueType(), L0, Shl, L2);
+ LHS.getValueType(), L0, ExtVal, L2);
return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, N->getValueType(0),
Pred, NewLHS, RHS, N->getOperand(3));
}
diff --git a/llvm/test/CodeGen/AArch64/alias_mask.ll b/llvm/test/CodeGen/AArch64/alias_mask.ll
index 7a57fc3be84ac..42833aa19a7fd 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask.ll
@@ -110,10 +110,6 @@ define <32 x i1> @whilewr_8_split(ptr %a, ptr %b) {
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI8_0]
; CHECK-NEXT: mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: shl v0.16b, v0.16b, #7
-; CHECK-NEXT: shl v1.16b, v1.16b, #7
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
; CHECK-NEXT: and v1.16b, v1.16b, v2.16b
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
@@ -135,48 +131,40 @@ define <64 x i1> @whilewr_8_split2(ptr %a, ptr %b) {
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub x9, x1, x0
; CHECK-NEXT: mov w10, #48 // =0x30
-; CHECK-NEXT: mov w11, #32 // =0x20
+; CHECK-NEXT: mov w11, #16 // =0x10
; CHECK-NEXT: cmp x9, #1
+; CHECK-NEXT: mov w12, #32 // =0x20
; CHECK-NEXT: csinv x9, x9, xzr, ge
; CHECK-NEXT: whilewr p0.b, x0, x1
; CHECK-NEXT: whilelo p1.b, x10, x9
-; CHECK-NEXT: mov w10, #16 // =0x10
+; CHECK-NEXT: adrp x10, .LCPI9_0
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: whilelo p0.b, x11, x9
-; CHECK-NEXT: mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: whilelo p1.b, x10, x9
-; CHECK-NEXT: adrp x9, .LCPI9_0
-; CHECK-NEXT: mov z2.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: ldr q4, [x9, :lo12:.LCPI9_0]
-; CHECK-NEXT: mov z3.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: shl v0.16b, v0.16b, #7
-; CHECK-NEXT: shl v1.16b, v1.16b, #7
-; CHECK-NEXT: shl v2.16b, v2.16b, #7
-; CHECK-NEXT: shl v3.16b, v3.16b, #7
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT: cmlt v2.16b, v2.16b, #0
-; CHECK-NEXT: cmlt v3.16b, v3.16b, #0
-; CHECK-NEXT: and v0.16b, v0.16b, v4.16b
-; CHECK-NEXT: and v1.16b, v1.16b, v4.16b
-; CHECK-NEXT: and v2.16b, v2.16b, v4.16b
-; CHECK-NEXT: and v3.16b, v3.16b, v4.16b
+; CHECK-NEXT: whilelo p0.b, x12, x9
+; CHECK-NEXT: ldr q1, [x10, :lo12:.LCPI9_0]
+; CHECK-NEXT: mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: whilelo p1.b, x11, x9
+; CHECK-NEXT: mov z3.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: mov z4.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: and v2.16b, v2.16b, v1.16b
+; CHECK-NEXT: and v3.16b, v3.16b, v1.16b
+; CHECK-NEXT: and v1.16b, v4.16b, v1.16b
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT: ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT: ext v7.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT: ext v6.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: ext v7.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: zip1 v0.16b, v0.16b, v4.16b
-; CHECK-NEXT: zip1 v1.16b, v1.16b, v5.16b
-; CHECK-NEXT: zip1 v2.16b, v2.16b, v6.16b
-; CHECK-NEXT: zip1 v3.16b, v3.16b, v7.16b
+; CHECK-NEXT: zip1 v2.16b, v2.16b, v5.16b
+; CHECK-NEXT: zip1 v3.16b, v3.16b, v6.16b
+; CHECK-NEXT: zip1 v1.16b, v1.16b, v7.16b
; CHECK-NEXT: addv h0, v0.8h
-; CHECK-NEXT: addv h1, v1.8h
; CHECK-NEXT: addv h2, v2.8h
; CHECK-NEXT: addv h3, v3.8h
+; CHECK-NEXT: addv h1, v1.8h
; CHECK-NEXT: str h0, [x8]
-; CHECK-NEXT: str h1, [x8, #6]
-; CHECK-NEXT: str h2, [x8, #4]
-; CHECK-NEXT: str h3, [x8, #2]
+; CHECK-NEXT: str h2, [x8, #6]
+; CHECK-NEXT: str h3, [x8, #4]
+; CHECK-NEXT: str h1, [x8, #2]
; CHECK-NEXT: ret
entry:
%0 = call <64 x i1> @llvm.loop.dependence.war.mask.v64i1(ptr %a, ptr %b, i64 1)
@@ -213,14 +201,10 @@ define <32 x i1> @whilewr_16_expand2(ptr %a, ptr %b) {
; CHECK-NEXT: whilelo p1.b, xzr, x9
; CHECK-NEXT: adrp x9, .LCPI11_0
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI11_0]
-; CHECK-NEXT: mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: shl v0.16b, v0.16b, #7
-; CHECK-NEXT: shl v1.16b, v1.16b, #7
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT: and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT: ldr q1, [x9, :lo12:.LCPI11_0]
+; CHECK-NEXT: mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: and v1.16b, v2.16b, v1.16b
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: zip1 v0.16b, v0.16b, v2.16b
@@ -285,14 +269,10 @@ define <32 x i1> @whilewr_32_expand3(ptr %a, ptr %b) {
; CHECK-NEXT: whilelo p1.b, xzr, x9
; CHECK-NEXT: adrp x9, .LCPI14_0
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI14_0]
-; CHECK-NEXT: mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: shl v0.16b, v0.16b, #7
-; CHECK-NEXT: shl v1.16b, v1.16b, #7
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT: and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT: ldr q1, [x9, :lo12:.LCPI14_0]
+; CHECK-NEXT: mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: and v1.16b, v2.16b, v1.16b
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: zip1 v0.16b, v0.16b, v2.16b
@@ -375,14 +355,10 @@ define <32 x i1> @whilewr_64_expand4(ptr %a, ptr %b) {
; CHECK-NEXT: whilelo p1.b, xzr, x9
; CHECK-NEXT: adrp x9, .LCPI18_0
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI18_0]
-; CHECK-NEXT: mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: shl v0.16b, v0.16b, #7
-; CHECK-NEXT: shl v1.16b, v1.16b, #7
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT: and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT: ldr q1, [x9, :lo12:.LCPI18_0]
+; CHECK-NEXT: mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: and v1.16b, v2.16b, v1.16b
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: zip1 v0.16b, v0.16b, v2.16b
diff --git a/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll b/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll
new file mode 100644
index 0000000000000..783cd6f693ed8
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <16 x i8> @active_lane_mask_mload(ptr %p, i64 %n) {
+; CHECK-LABEL: active_lane_mask_mload:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.b, vl16
+; CHECK-NEXT: whilelo p1.b, xzr, x1
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %mask = call <16 x i1> @llvm.get.active.lane.mask.v16i1(i64 0, i64 %n)
+ %data = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr %p, i32 1, <16 x i1> %mask, <16 x i8> zeroinitializer)
+ ret <16 x i8> %data
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/176883
More information about the llvm-commits mailing list