[llvm] [LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops. (PR #134402)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 4 08:52:37 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Paul Walker (paulwalker-arm)
<details>
<summary>Changes</summary>
Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is the result of a compare we can instead promote the operands and regenerate the original compare. At worst this reduces the dependency chain and in most cases removes the need for multiple compares.
---
Full diff: https://github.com/llvm/llvm-project/pull/134402.diff
4 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+30-5)
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll (+2-3)
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll (+12-18)
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll (+2-3)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a1ba3922996a1..57a950cfc702a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
EVT VecVT = Vec.getValueType();
EVT SubVT = SubVec.getValueType();
+ // Promote fixed length vector zeros.
+ if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+ Vec.isUndef() && isZerosVector(SubVec.getNode()))
+ return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+ : DAG.getConstantFP(0, DL, VecVT);
+
// Only do this for legal fixed vector types.
if (!VecVT.isFixedLengthVector() ||
!DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
SDLoc DL(Mask);
EVT InVT = Mask.getValueType();
EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
- auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+ SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
if (ISD::isBuildVectorAllOnes(Mask.getNode()))
return Pg;
- auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
- auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+ bool InvertCond = false;
+ if (isBitwiseNot(Mask)) {
+ InvertCond = true;
+ Mask = Mask.getOperand(0);
+ }
+
+ SDValue Op1, Op2;
+ ISD::CondCode CC;
+
+ // When Mask is the result of a SETCC, it's better to regenerate the compare.
+ if (Mask.getOpcode() == ISD::SETCC) {
+ Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+ Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+ CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+ } else {
+ Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+ Op2 = DAG.getConstant(0, DL, ContainerVT);
+ CC = ISD::SETNE;
+ }
+
+ if (InvertCond)
+ CC = getSetCCInverse(CC, Op1.getValueType());
return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
- {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+ {Pg, Op1, Op2, DAG.getCondCode(CC)});
}
// Convert all fixed length vector loads larger than NEON to masked_loads.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index a50d0dc37eaf6..093e6cd9328c8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
; CHECK-LABEL: masked_gather_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT: cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT: ldr q0, [x0]
+; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: ldr q0, [x1]
; CHECK-NEXT: ld1d { z0.d }, p0/z, [z0.d]
; CHECK-NEXT: str q0, [x0]
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
index 6513b01d00922..34dc0bb5ef2d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl16
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT: cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl8
+; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.h, vl8
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT: cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT: cmpeq p0.h, p0/z, z0.h, #0
; VBITS_GE_256-NEXT: ld1h { z0.h }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl16
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT: cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl8
+; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.h, vl8
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT: cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT: cmpeq p0.h, p0/z, z0.h, #0
; VBITS_GE_256-NEXT: ld1h { z0.h }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index a42fce70f4f15..ed03f9b322432 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
; CHECK-LABEL: masked_scatter_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ldr q1, [x1]
+; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: st1d { z0.d }, p0, [z1.d]
; CHECK-NEXT: ret
%vals = load <2 x i64>, ptr %a
``````````
</details>
https://github.com/llvm/llvm-project/pull/134402
More information about the llvm-commits
mailing list