[llvm] [DAGCombiner] Remove a hasOneUse check in visitAND (PR #115142)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 6 01:59:05 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: David Sherwood (david-arm)
<details>
<summary>Changes</summary>
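visitAND had a hasOneUse check on the splat used as the second operand
of the AND, and the reason for that check is not obvious. The check
blocks the (and (masked_load) (splat_vec ...)) -> zext_masked_load fold,
which in turn blocks optimisations when lowering nodes such as
AVGFLOORU and AVGCEILU.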
In a follow-on patch I plan to further improve the code generated
for AVGCEILU by teaching computeKnownBits about zero-extending
masked loads.
---
Full diff: https://github.com/llvm/llvm-project/pull/115142.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-2)
- (modified) llvm/test/CodeGen/AArch64/avg.ll (+47-1)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7eef09e55101d0..f718cbf65480ab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7095,8 +7095,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
- if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
- N1.hasOneUse()) {
+ if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
EVT LoadVT = MLoad->getMemoryVT();
EVT ExtVT = VT;
if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index ea07b10c22c2e7..aac797aafcf2eb 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
-; RUN: llc -mtriple=aarch64 < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
define <16 x i16> @zext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
; CHECK-LABEL: zext_avgflooru:
@@ -17,6 +17,28 @@ define <16 x i16> @zext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
ret <16 x i16> %avg
}
+define void @zext_mload_avgflooru(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: zext_mload_avgflooru:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z0.h }, p0/z, [x0]
+; CHECK-NEXT: ld1b { z1.h }, p0/z, [x1]
+; CHECK-NEXT: eor z2.d, z0.d, z1.d
+; CHECK-NEXT: and z0.d, z0.d, z1.d
+; CHECK-NEXT: lsr z1.h, z2.h, #1
+; CHECK-NEXT: add z0.h, z0.h, z1.h
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+; CHECK-NEXT: ret
+ %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+ %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+ %and = and <vscale x 8 x i8> %ld1, %ld2
+ %xor = xor <vscale x 8 x i8> %ld1, %ld2
+ %shift = lshr <vscale x 8 x i8> %xor, splat(i8 1)
+ %avg = add <vscale x 8 x i8> %and, %shift
+ %avgext = zext <vscale x 8 x i8> %avg to <vscale x 8 x i16>
+ call void @llvm.masked.store.nxv8i16(<vscale x 8 x i16> %avgext, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+ ret void
+}
+
define <16 x i16> @zext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
; CHECK-LABEL: zext_avgflooru_mismatch:
; CHECK: // %bb.0:
@@ -51,6 +73,30 @@ define <16 x i16> @zext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
ret <16 x i16> %avg
}
+define void @zext_mload_avgceilu(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: zext_mload_avgceilu:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z0.h }, p0/z, [x0]
+; CHECK-NEXT: ld1b { z1.h }, p0/z, [x1]
+; CHECK-NEXT: eor z2.d, z0.d, z1.d
+; CHECK-NEXT: orr z0.d, z0.d, z1.d
+; CHECK-NEXT: lsr z1.h, z2.h, #1
+; CHECK-NEXT: sub z0.h, z0.h, z1.h
+; CHECK-NEXT: st1b { z0.h }, p0, [x0]
+; CHECK-NEXT: ret
+ %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+ %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+ %zext1 = zext <vscale x 8 x i8> %ld1 to <vscale x 8 x i16>
+ %zext2 = zext <vscale x 8 x i8> %ld2 to <vscale x 8 x i16>
+ %add1 = add nuw nsw <vscale x 8 x i16> %zext1, splat(i16 1)
+ %add2 = add nuw nsw <vscale x 8 x i16> %add1, %zext2
+ %shift = lshr <vscale x 8 x i16> %add2, splat(i16 1)
+ %trunc = trunc <vscale x 8 x i16> %shift to <vscale x 8 x i8>
+ call void @llvm.masked.store.nxv8i8(<vscale x 8 x i8> %trunc, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+
define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
; CHECK-LABEL: zext_avgceilu_mismatch:
; CHECK: // %bb.0:
``````````
</details>
https://github.com/llvm/llvm-project/pull/115142
More information about the llvm-commits mailing list