[llvm] [AArch64] Fold sext-in-reg for predicate -> fixed-length conversions. (PR #176883)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 20 01:57:53 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes:

A conversion from an SVE predicate to a fixed-length vector, i.e.
`v16i8 extract_subvector(nxv16i8 sign_extend(nxv16i1 v), 0)`, already
produces lanes that are all-zeros or all-ones, so a subsequent
sign-extend-in-reg (lowered as `VASHR(VSHL(x))`, which shows up as
`shl #7` + `cmlt #0` in the emitted code) is a no-op. Teach
`performVectorShiftCombine` and `performSetccMergeZeroCombine` to
recognise this pattern and fold the redundant sign extension away.

---
Full diff: https://github.com/llvm/llvm-project/pull/176883.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+32-7) 
- (modified) llvm/test/CodeGen/AArch64/alias_mask.ll (+35-59) 
- (added) llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll (+19) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c81f457898d7c..0a4216cbdda0d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24690,6 +24690,19 @@ static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+// Returns true when V is the following pattern:
+//    v16i8 extract_subvector(
+//      nxv16i8 sign_extend (nxv16i1 v))
+static bool isPredicateToFixedLengthVectorConversion(SDValue V) {
+  if (V.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+    return false;
+
+  SDValue Src = V.getOperand(0);
+  return V.getValueType() == MVT::v16i8 && Src.getValueType() == MVT::nxv16i8 &&
+         Src.getOpcode() == ISD::SIGN_EXTEND &&
+         Src.getOperand(0).getValueType() == MVT::nxv16i1;
+}
+
 /// Optimize a vector shift instruction and its operand if shifted out
 /// bits are not used.
 static SDValue performVectorShiftCombine(SDNode *N,
@@ -24707,9 +24720,12 @@ static SDValue performVectorShiftCombine(SDNode *N,
   // Remove sign_extend_inreg (ashr(shl(x)) based on the number of sign bits.
   if (N->getOpcode() == AArch64ISD::VASHR &&
       Op.getOpcode() == AArch64ISD::VSHL &&
-      N->getOperand(1) == Op.getOperand(1))
-    if (DCI.DAG.ComputeNumSignBits(Op.getOperand(0)) > ShiftImm)
-      return Op.getOperand(0);
+      N->getOperand(1) == Op.getOperand(1)) {
+    SDValue ShiftSrc = Op.getOperand(0);
+    if (isPredicateToFixedLengthVectorConversion(ShiftSrc) ||
+        DCI.DAG.ComputeNumSignBits(ShiftSrc) > ShiftImm)
+      return ShiftSrc;
+  }
 
   // If the shift is exact, the shifted out bits matter.
   if (N->getFlags().hasExact())
@@ -26951,9 +26967,13 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   }
 
   //    setcc_merge_zero(
-  //       pred, insert_subvector(undef, signext_inreg(vNi1), 0), != splat(0))
+  //       pred, insert_subvector(undef, signext_inreg(x), 0), != splat(0))
+  // => setcc_merge_zero(
+  //       pred, insert_subvector(undef, shl(x), 0), != splat(0))
+  // or:
   // => setcc_merge_zero(
-  //       pred, insert_subvector(undef, shl(vNi1), 0), != splat(0))
+  //       pred, insert_subvector(undef, x, 0), != splat(0))
+  // iff it can be proven that x is already sign-extended.
   if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
       LHS->getOpcode() == ISD::INSERT_SUBVECTOR && LHS.hasOneUse()) {
     SDValue L0 = LHS->getOperand(0);
@@ -26962,9 +26982,14 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
     if (L0.isUndef() && isNullConstant(L2) && isSignExtInReg(L1)) {
       SDLoc DL(N);
-      SDValue Shl = L1.getOperand(0);
+      SDValue ExtVal = L1.getOperand(0);
+      unsigned NumShiftBits = ExtVal.getConstantOperandVal(1);
+      SDValue ShlSrc = ExtVal.getOperand(0);
+      if (isPredicateToFixedLengthVectorConversion(ShlSrc) ||
+          DCI.DAG.ComputeNumSignBits(ShlSrc) > NumShiftBits)
+        ExtVal = ShlSrc;
       SDValue NewLHS = DAG.getNode(ISD::INSERT_SUBVECTOR, DL,
-                                   LHS.getValueType(), L0, Shl, L2);
+                                   LHS.getValueType(), L0, ExtVal, L2);
       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, N->getValueType(0),
                          Pred, NewLHS, RHS, N->getOperand(3));
     }
diff --git a/llvm/test/CodeGen/AArch64/alias_mask.ll b/llvm/test/CodeGen/AArch64/alias_mask.ll
index 7a57fc3be84ac..42833aa19a7fd 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask.ll
@@ -110,10 +110,6 @@ define <32 x i1> @whilewr_8_split(ptr %a, ptr %b) {
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
 ; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI8_0]
 ; CHECK-NEXT:    mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    shl v0.16b, v0.16b, #7
-; CHECK-NEXT:    shl v1.16b, v1.16b, #7
-; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
 ; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
 ; CHECK-NEXT:    and v1.16b, v1.16b, v2.16b
 ; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
@@ -135,48 +131,40 @@ define <64 x i1> @whilewr_8_split2(ptr %a, ptr %b) {
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sub x9, x1, x0
 ; CHECK-NEXT:    mov w10, #48 // =0x30
-; CHECK-NEXT:    mov w11, #32 // =0x20
+; CHECK-NEXT:    mov w11, #16 // =0x10
 ; CHECK-NEXT:    cmp x9, #1
+; CHECK-NEXT:    mov w12, #32 // =0x20
 ; CHECK-NEXT:    csinv x9, x9, xzr, ge
 ; CHECK-NEXT:    whilewr p0.b, x0, x1
 ; CHECK-NEXT:    whilelo p1.b, x10, x9
-; CHECK-NEXT:    mov w10, #16 // =0x10
+; CHECK-NEXT:    adrp x10, .LCPI9_0
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    whilelo p0.b, x11, x9
-; CHECK-NEXT:    mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    whilelo p1.b, x10, x9
-; CHECK-NEXT:    adrp x9, .LCPI9_0
-; CHECK-NEXT:    mov z2.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ldr q4, [x9, :lo12:.LCPI9_0]
-; CHECK-NEXT:    mov z3.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    shl v0.16b, v0.16b, #7
-; CHECK-NEXT:    shl v1.16b, v1.16b, #7
-; CHECK-NEXT:    shl v2.16b, v2.16b, #7
-; CHECK-NEXT:    shl v3.16b, v3.16b, #7
-; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT:    cmlt v2.16b, v2.16b, #0
-; CHECK-NEXT:    cmlt v3.16b, v3.16b, #0
-; CHECK-NEXT:    and v0.16b, v0.16b, v4.16b
-; CHECK-NEXT:    and v1.16b, v1.16b, v4.16b
-; CHECK-NEXT:    and v2.16b, v2.16b, v4.16b
-; CHECK-NEXT:    and v3.16b, v3.16b, v4.16b
+; CHECK-NEXT:    whilelo p0.b, x12, x9
+; CHECK-NEXT:    ldr q1, [x10, :lo12:.LCPI9_0]
+; CHECK-NEXT:    mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    whilelo p1.b, x11, x9
+; CHECK-NEXT:    mov z3.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z4.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    and v2.16b, v2.16b, v1.16b
+; CHECK-NEXT:    and v3.16b, v3.16b, v1.16b
+; CHECK-NEXT:    and v1.16b, v4.16b, v1.16b
 ; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    ext v7.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT:    ext v6.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    ext v7.16b, v1.16b, v1.16b, #8
 ; CHECK-NEXT:    zip1 v0.16b, v0.16b, v4.16b
-; CHECK-NEXT:    zip1 v1.16b, v1.16b, v5.16b
-; CHECK-NEXT:    zip1 v2.16b, v2.16b, v6.16b
-; CHECK-NEXT:    zip1 v3.16b, v3.16b, v7.16b
+; CHECK-NEXT:    zip1 v2.16b, v2.16b, v5.16b
+; CHECK-NEXT:    zip1 v3.16b, v3.16b, v6.16b
+; CHECK-NEXT:    zip1 v1.16b, v1.16b, v7.16b
 ; CHECK-NEXT:    addv h0, v0.8h
-; CHECK-NEXT:    addv h1, v1.8h
 ; CHECK-NEXT:    addv h2, v2.8h
 ; CHECK-NEXT:    addv h3, v3.8h
+; CHECK-NEXT:    addv h1, v1.8h
 ; CHECK-NEXT:    str h0, [x8]
-; CHECK-NEXT:    str h1, [x8, #6]
-; CHECK-NEXT:    str h2, [x8, #4]
-; CHECK-NEXT:    str h3, [x8, #2]
+; CHECK-NEXT:    str h2, [x8, #6]
+; CHECK-NEXT:    str h3, [x8, #4]
+; CHECK-NEXT:    str h1, [x8, #2]
 ; CHECK-NEXT:    ret
 entry:
   %0 = call <64 x i1> @llvm.loop.dependence.war.mask.v64i1(ptr %a, ptr %b, i64 1)
@@ -213,14 +201,10 @@ define <32 x i1> @whilewr_16_expand2(ptr %a, ptr %b) {
 ; CHECK-NEXT:    whilelo p1.b, xzr, x9
 ; CHECK-NEXT:    adrp x9, .LCPI11_0
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI11_0]
-; CHECK-NEXT:    mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    shl v0.16b, v0.16b, #7
-; CHECK-NEXT:    shl v1.16b, v1.16b, #7
-; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT:    ldr q1, [x9, :lo12:.LCPI11_0]
+; CHECK-NEXT:    mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    and v1.16b, v2.16b, v1.16b
 ; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
 ; CHECK-NEXT:    zip1 v0.16b, v0.16b, v2.16b
@@ -285,14 +269,10 @@ define <32 x i1> @whilewr_32_expand3(ptr %a, ptr %b) {
 ; CHECK-NEXT:    whilelo p1.b, xzr, x9
 ; CHECK-NEXT:    adrp x9, .LCPI14_0
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI14_0]
-; CHECK-NEXT:    mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    shl v0.16b, v0.16b, #7
-; CHECK-NEXT:    shl v1.16b, v1.16b, #7
-; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT:    ldr q1, [x9, :lo12:.LCPI14_0]
+; CHECK-NEXT:    mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    and v1.16b, v2.16b, v1.16b
 ; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
 ; CHECK-NEXT:    zip1 v0.16b, v0.16b, v2.16b
@@ -375,14 +355,10 @@ define <32 x i1> @whilewr_64_expand4(ptr %a, ptr %b) {
 ; CHECK-NEXT:    whilelo p1.b, xzr, x9
 ; CHECK-NEXT:    adrp x9, .LCPI18_0
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI18_0]
-; CHECK-NEXT:    mov z1.b, p1/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    shl v0.16b, v0.16b, #7
-; CHECK-NEXT:    shl v1.16b, v1.16b, #7
-; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
-; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT:    ldr q1, [x9, :lo12:.LCPI18_0]
+; CHECK-NEXT:    mov z2.b, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    and v1.16b, v2.16b, v1.16b
 ; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
 ; CHECK-NEXT:    zip1 v0.16b, v0.16b, v2.16b
diff --git a/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll b/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll
new file mode 100644
index 0000000000000..783cd6f693ed8
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fold-sext-in-reg-predicate-fixed-length.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <16 x i8> @active_lane_mask_mload(ptr %p, i64 %n) {
+; CHECK-LABEL: active_lane_mask_mload:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    whilelo p1.b, xzr, x1
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %mask = call <16 x i1> @llvm.get.active.lane.mask.v16i1(i64 0, i64 %n)
+  %data = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr %p, i32 1, <16 x i1> %mask, <16 x i8> zeroinitializer)
+  ret <16 x i8> %data
+}

``````````
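For readers skimming the combines: the fold is sound because a predicate-to-fixed-length conversion yields lanes that are already fully sign-extended (every lane is 0 or -1). Below is a minimal sketch, my own illustration rather than one of the patch's tests, of the IR shape that `isPredicateToFixedLengthVectorConversion` recognises. The fold itself fires once a later combine wraps such a value in a sign-extend-in-reg idiom, either `VASHR(VSHL(x))` or a `shl` under a `setcc_merge_zero`, as the masked-load test above does; the dedicated pattern check is presumably needed because `ComputeNumSignBits` alone does not see through the scalable-to-fixed `extract_subvector` boundary here.

``````````llvm
; Compile with: llc -mattr=+sve
target triple = "aarch64"

; Each lane of %lo is 0x00 or 0xff, so a sign-extend-in-reg of it is the
; identity and can be dropped.
define <16 x i8> @predicate_to_fixed(<vscale x 16 x i1> %p) {
  ; nxv16i1 -> sign_extend -> nxv16i8 -> extract low 128 bits -> v16i8:
  ; exactly the v16i8 extract_subvector(nxv16i8 sign_extend(nxv16i1 v))
  ; pattern matched by the new helper.
  %ext = sext <vscale x 16 x i1> %p to <vscale x 16 x i8>
  %lo = call <16 x i8> @llvm.vector.extract.v16i8.nxv16i8(<vscale x 16 x i8> %ext, i64 0)
  ret <16 x i8> %lo
}
``````````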

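The new test exercises the masked-load path through `performSetccMergeZeroCombine`. A masked store built from the same lane mask should, on the assumption (mine, not the patch's) that masked stores are legalised through the same predicate-conversion route, benefit from the identical fold and emit no `shl #7`/`cmlt #0` pair either:

``````````llvm
; Compile with: llc -mattr=+sve
target triple = "aarch64"

; Hypothetical companion to the masked-load test above: the <16 x i1> mask
; is materialised as an SVE predicate, so re-extending it before the compare
; that rebuilds the predicate should be unnecessary.
define void @active_lane_mask_mstore(ptr %p, <16 x i8> %v, i64 %n) {
entry:
  %mask = call <16 x i1> @llvm.get.active.lane.mask.v16i1(i64 0, i64 %n)
  call void @llvm.masked.store.v16i8.p0(<16 x i8> %v, ptr %p, i32 1, <16 x i1> %mask)
  ret void
}
``````````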


https://github.com/llvm/llvm-project/pull/176883

