[llvm] [LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops. (PR #134402)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 08:52:37 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is itself the result of a compare, we can instead promote the compare's operands and regenerate the original compare on scalable vectors. At worst this shortens the dependency chain, and in most cases it removes the need for multiple compares.
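
For context, a minimal IR sketch of the pattern this targets (the function name is hypothetical, not taken from the tests): a masked load whose mask is produced by a vector compare. Previously the compare result was materialised as a fixed length vector and re-tested against zero to form the predicate; with this change the compare is regenerated directly as a single predicated SVE compare.

```llvm
; Masked load whose mask comes from an icmp. The lowering can now emit one
; cmpeq on scalable vectors instead of a NEON cmeq followed by an SVE cmpne.
define <2 x i64> @mload_with_cmp_mask(ptr %a, ptr %b) {
  %vals = load <2 x i64>, ptr %b
  %mask = icmp eq <2 x i64> %vals, zeroinitializer
  %res = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %a, i32 8, <2 x i1> %mask, <2 x i64> poison)
  ret <2 x i64> %res
}
declare <2 x i64> @llvm.masked.load.v2i64.p0(ptr, i32, <2 x i1>, <2 x i64>)
```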

---
Full diff: https://github.com/llvm/llvm-project/pull/134402.diff


4 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+30-5) 
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll (+2-3) 
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll (+12-18) 
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll (+2-3) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a1ba3922996a1..57a950cfc702a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  // Promote fixed length vector zeros.
+  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+      Vec.isUndef() && isZerosVector(SubVec.getNode()))
+    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+                             : DAG.getConstantFP(0, DL, VecVT);
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
   SDLoc DL(Mask);
   EVT InVT = Mask.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
-  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
 
   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
     return Pg;
 
-  auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
-  auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+  bool InvertCond = false;
+  if (isBitwiseNot(Mask)) {
+    InvertCond = true;
+    Mask = Mask.getOperand(0);
+  }
+
+  SDValue Op1, Op2;
+  ISD::CondCode CC;
+
+  // When Mask is the result of a SETCC, it's better to regenerate the compare.
+  if (Mask.getOpcode() == ISD::SETCC) {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+    Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+    CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+  } else {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+    Op2 = DAG.getConstant(0, DL, ContainerVT);
+    CC = ISD::SETNE;
+  }
+
+  if (InvertCond)
+    CC = getSetCCInverse(CC, Op1.getValueType());
 
   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
-                     {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+                     {Pg, Op1, Op2, DAG.getCondCode(CC)});
 }
 
 // Convert all fixed length vector loads larger than NEON to masked_loads.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index a50d0dc37eaf6..093e6cd9328c8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_gather_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    ldr q0, [x1]
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [z0.d]
 ; CHECK-NEXT:    str q0, [x0]
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
index 6513b01d00922..34dc0bb5ef2d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index a42fce70f4f15..ed03f9b322432 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_scatter_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
 ; CHECK-NEXT:    ret
   %vals = load <2 x i64>, ptr %a

``````````
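
The isBitwiseNot handling in convertFixedMaskToScalableVector also covers masks that arrive as the logical negation of a compare. A hedged sketch of that input shape (hypothetical function name): the xor with all-ones is absorbed by inverting the condition code via getSetCCInverse, so the eq below would be regenerated as an SVE cmpne rather than needing an explicit vector NOT.

```llvm
; Mask is the negation of a compare; the lowering inverts the condition
; (eq -> ne) instead of materialising the inverted mask vector.
define <2 x i64> @mload_with_inverted_cmp_mask(ptr %a, ptr %b) {
  %vals = load <2 x i64>, ptr %b
  %cmp = icmp eq <2 x i64> %vals, zeroinitializer
  %mask = xor <2 x i1> %cmp, <i1 true, i1 true>
  %res = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %a, i32 8, <2 x i1> %mask, <2 x i64> poison)
  ret <2 x i64> %res
}
declare <2 x i64> @llvm.masked.load.v2i64.p0(ptr, i32, <2 x i1>, <2 x i64>)
```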

</details>


https://github.com/llvm/llvm-project/pull/134402

