[llvm] [DAGCombiner] Remove a hasOneUse check in visitAND (PR #115142)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 01:59:05 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

Changes:

For some reason there was a hasOneUse check on the splat for the
second operand, and it's not obvious to me why. The check blocks
optimisations when lowering nodes such as AVGFLOORU and AVGCEILU.
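
For context, the fold that this check gates turns an AND of an
extending masked load with a splatted low-bit mask into a
zero-extending masked load. Below is a condensed sketch paraphrased
from my reading of the surrounding visitAND code (not copied from
this diff), shown with the single-use requirement already dropped:

```cpp
// Condensed sketch of the fold in visitAND (paraphrased, not verbatim).
// N is (and N0, N1) and VT is the value type of N.
auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
ConstantSDNode *Splat = isConstOrConstSplat(N1, /*AllowUndefs=*/true,
                                            /*AllowTruncation=*/true);
if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
  EVT LoadVT = MLoad->getMemoryVT();
  // The AND behaves as a zero-extension of the load only if the splat
  // masks off exactly the bits above the loaded element width.
  if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, LoadVT) &&
      Splat->getAPIntValue().isMask(
          LoadVT.getVectorElementType().getScalarSizeInBits())) {
    SDValue ZExtLoad = DAG.getMaskedLoad(
        VT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
        MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(), LoadVT,
        MLoad->getMemOperand(), MLoad->getAddressingMode(), ISD::ZEXTLOAD,
        MLoad->isExpandingLoad());
    // The real code also replaces uses of the original load's chain.
    CombineTo(N, ZExtLoad);
    return SDValue(N, 0);
  }
}
```

Without the N1.hasOneUse() requirement the fold can also fire when the
splat constant has other uses, which is presumably what was blocking
the new masked-load AVG cases tested below.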

In a follow-on patch I plan to further improve the generated code
for AVGCEILU by teaching computeKnownBits about zero-extending
masked loads.
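
To make that follow-on concrete, here is a purely hypothetical sketch
of what such a computeKnownBits case might look like; the opcode case,
variable names, and the note about the pass-through operand are my
assumptions, not code from this PR:

```cpp
// Hypothetical sketch only -- not part of this patch.
// A zero-extending masked load leaves every bit above the memory
// element width as known zero.
case ISD::MLOAD: {
  auto *MLd = cast<MaskedLoadSDNode>(Op);
  if (MLd->getExtensionType() == ISD::ZEXTLOAD) {
    unsigned MemBits = MLd->getMemoryVT().getScalarSizeInBits();
    Known.Zero.setBitsFrom(MemBits);
    // A real implementation would also need to intersect this with the
    // known bits of the pass-through value used for masked-off lanes.
  }
  break;
}
```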

---
Full diff: https://github.com/llvm/llvm-project/pull/115142.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-2) 
- (modified) llvm/test/CodeGen/AArch64/avg.ll (+47-1) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7eef09e55101d0..f718cbf65480ab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7095,8 +7095,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
     ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
-    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
-        N1.hasOneUse()) {
+    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index ea07b10c22c2e7..aac797aafcf2eb 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
-; RUN: llc -mtriple=aarch64 < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
 
 define <16 x i16> @zext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
 ; CHECK-LABEL: zext_avgflooru:
@@ -17,6 +17,28 @@ define <16 x i16> @zext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
   ret <16 x i16> %avg
 }
 
+define void @zext_mload_avgflooru(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: zext_mload_avgflooru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; CHECK-NEXT:    eor z2.d, z0.d, z1.d
+; CHECK-NEXT:    and z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsr z1.h, z2.h, #1
+; CHECK-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
+; CHECK-NEXT:    ret
+  %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %and = and <vscale x 8 x i8> %ld1, %ld2
+  %xor = xor <vscale x 8 x i8> %ld1, %ld2
+  %shift = lshr <vscale x 8 x i8> %xor, splat(i8 1)
+  %avg = add <vscale x 8 x i8> %and, %shift
+  %avgext = zext <vscale x 8 x i8> %avg to <vscale x 8 x i16>
+  call void @llvm.masked.store.nxv8i16(<vscale x 8 x i16> %avgext, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+  ret void
+}
+
 define <16 x i16> @zext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
 ; CHECK-LABEL: zext_avgflooru_mismatch:
 ; CHECK:       // %bb.0:
@@ -51,6 +73,30 @@ define <16 x i16> @zext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
   ret <16 x i16> %avg
 }
 
+define void @zext_mload_avgceilu(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: zext_mload_avgceilu:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; CHECK-NEXT:    eor z2.d, z0.d, z1.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsr z1.h, z2.h, #1
+; CHECK-NEXT:    sub z0.h, z0.h, z1.h
+; CHECK-NEXT:    st1b { z0.h }, p0, [x0]
+; CHECK-NEXT:    ret
+  %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %zext1 = zext <vscale x 8 x i8> %ld1 to <vscale x 8 x i16>
+  %zext2 = zext <vscale x 8 x i8> %ld2 to <vscale x 8 x i16>
+  %add1 = add nuw nsw <vscale x 8 x i16> %zext1, splat(i16 1)
+  %add2 = add nuw nsw <vscale x 8 x i16> %add1, %zext2
+  %shift = lshr <vscale x 8 x i16> %add2, splat(i16 1)
+  %trunc = trunc <vscale x 8 x i16> %shift to <vscale x 8 x i8>
+  call void @llvm.masked.store.nxv8i8(<vscale x 8 x i8> %trunc, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+
 define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
 ; CHECK-LABEL: zext_avgceilu_mismatch:
 ; CHECK:       // %bb.0:

``````````



https://github.com/llvm/llvm-project/pull/115142

