[llvm] [LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops. (PR #134402)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 08:52:04 PDT 2025


https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/134402

Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is itself the result of a compare, we can instead promote the compare's operands and regenerate the original compare. At worst this reduces the dependency chain, and in most cases it removes the need for multiple compares.
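As a concrete illustration, taken from the masked_gather_v2i64 change below (the IR line is a hypothetical mask definition, written here only to show the pattern the combine targets; the instruction sequences are copied from the test update):

    ; Hypothetical mask feeding the gather:
    ;   %mask = icmp eq <2 x i64> %vals, zeroinitializer
    ;
    ; Before: the fixed-length compare is materialised with NEON and its
    ; result is then re-tested against zero to build the predicate.
    ;   cmeq  v0.2d, v0.2d, #0
    ;   cmpne p0.d, p0/z, z0.d, #0
    ;
    ; After: the compare's operands are promoted and the original compare is
    ; regenerated directly as a single predicated SVE compare.
    ;   cmpeq p0.d, p0/z, z0.d, #0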

From ff59fd2d48d6367a62252bf65ac8bf82fb0a13e9 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 21 Mar 2025 18:45:06 +0000
Subject: [PATCH] [LLVM][CodeGen][SVE] Improve lowering of fixed length masked
 mem ops.

Converting fixed length masks, as used by MLOAD, to scalable vectors
is done by comparing the mask to zero. When the mask is the result
of a compare we can instead promote the operands and regenerate the
original compare. At worst this reduces the dependency chain and in
most cases removes the need for multiple compares.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 35 ++++++++++++++++---
 .../AArch64/sve-fixed-length-masked-gather.ll |  5 ++-
 .../AArch64/sve-fixed-length-masked-loads.ll  | 30 +++++++---------
 .../sve-fixed-length-masked-scatter.ll        |  5 ++-
 4 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a1ba3922996a1..57a950cfc702a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  // Promote fixed length vector zeros.
+  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+      Vec.isUndef() && isZerosVector(SubVec.getNode()))
+    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+                             : DAG.getConstantFP(0, DL, VecVT);
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
   SDLoc DL(Mask);
   EVT InVT = Mask.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
-  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
 
   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
     return Pg;
 
-  auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
-  auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+  bool InvertCond = false;
+  if (isBitwiseNot(Mask)) {
+    InvertCond = true;
+    Mask = Mask.getOperand(0);
+  }
+
+  SDValue Op1, Op2;
+  ISD::CondCode CC;
+
+  // When Mask is the result of a SETCC, it's better to regenerate the compare.
+  if (Mask.getOpcode() == ISD::SETCC) {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+    Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+    CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+  } else {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+    Op2 = DAG.getConstant(0, DL, ContainerVT);
+    CC = ISD::SETNE;
+  }
+
+  if (InvertCond)
+    CC = getSetCCInverse(CC, Op1.getValueType());
 
   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
-                     {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+                     {Pg, Op1, Op2, DAG.getCondCode(CC)});
 }
 
 // Convert all fixed length vector loads larger than NEON to masked_loads.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index a50d0dc37eaf6..093e6cd9328c8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_gather_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    ldr q0, [x1]
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [z0.d]
 ; CHECK-NEXT:    str q0, [x0]
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
index 6513b01d00922..34dc0bb5ef2d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index a42fce70f4f15..ed03f9b322432 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_scatter_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
 ; CHECK-NEXT:    ret
   %vals = load <2 x i64>, ptr %a


