[llvm] [DAGCombine] Fix multi-use miscompile in load combine (PR #81492)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 07:51:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-selectiondag

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

The load combine replaces a number of original loads with one new load, and also replaces the output chains of the original loads with the output chain of the new load. This is only correct if the old loads actually get removed; otherwise they may get incorrectly reordered.

The code did enforce that all involved operations are one-use (which also guarantees that the loads will be removed), with one exception: for vector loads, multi-use was allowed to support multiple element extracts from one load.

This patch collects these extract-element nodes and then validates that the vector loads are only used by them.

I think an alternative fix would be to replace the uses of the old output chains with TokenFactors that include both the old output chains and the new output chain. However, I think the proposed patch is preferable: the profitability of the transform in the general multi-use case is unclear, since it may increase the overall number of loads.

Fixes https://github.com/llvm/llvm-project/issues/80911.

---
Full diff: https://github.com/llvm/llvm-project/pull/81492.diff


4 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+24-9) 
- (modified) llvm/test/CodeGen/AArch64/load-combine.ll (+5-3) 
- (modified) llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll (+6-4) 
- (modified) llvm/test/CodeGen/X86/load-combine.ll (+16-9) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d3cd9b1671e1b9..45114b85e25d8c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
 static std::optional<SDByteProvider>
 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                       std::optional<uint64_t> VectorIndex,
+                      SmallPtrSetImpl<SDNode *> &ExtractElements,
                       unsigned StartingIndex = 0) {
 
   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
 
   switch (Op.getOpcode()) {
   case ISD::OR: {
-    auto LHS =
-        calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
+    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!LHS)
       return std::nullopt;
-    auto RHS =
-        calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
+    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!RHS)
       return std::nullopt;
 
@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? SDByteProvider::getConstantZero()
                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
-                                       Depth + 1, VectorIndex, Index);
+                                       Depth + 1, VectorIndex, ExtractElements,
+                                       Index);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                        SDByteProvider::getConstantZero())
                  : std::nullopt;
     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
-                                 StartingIndex);
+                                 ExtractElements, StartingIndex);
   }
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
-                                 Depth + 1, VectorIndex, StartingIndex);
+                                 Depth + 1, VectorIndex, ExtractElements,
+                                 StartingIndex);
   case ISD::EXTRACT_VECTOR_ELT: {
     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!OffsetOp)
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
       return std::nullopt;
 
+    ExtractElements.insert(Op.getNode());
     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
-                                 VectorIndex, StartingIndex);
+                                 VectorIndex, ExtractElements, StartingIndex);
   }
   case ISD::LOAD: {
     auto L = cast<LoadSDNode>(Op.getNode());
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   SDValue Chain;
 
   SmallPtrSet<LoadSDNode *, 8> Loads;
+  SmallPtrSet<SDNode *, 8> ExtractElements;
   std::optional<SDByteProvider> FirstByteProvider;
   int64_t FirstOffset = INT64_MAX;
 
@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   unsigned ZeroExtendedBytes = 0;
   for (int i = ByteWidth - 1; i >= 0; --i) {
     auto P =
-        calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
+        calculateByteProvider(SDValue(N, 0), i, 0,
+                              /*VectorIndex*/ std::nullopt, ExtractElements,
+
                               /*StartingIndex*/ i);
     if (!P)
       return SDValue();
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   if (!Allowed || !Fast)
     return SDValue();
 
+  // calculateByteProvider() allows multi-use for vector loads. Ensure that
+  // all uses are in vector element extracts that are part of the pattern.
+  for (LoadSDNode *L : Loads)
+    if (L->getMemoryVT().isVector())
+      for (auto It = L->use_begin(); It != L->use_end(); ++It)
+        if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
+          return SDValue();
+
   SDValue NewLoad =
       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
                      Chain, FirstLoad->getBasePtr(),
diff --git a/llvm/test/CodeGen/AArch64/load-combine.ll b/llvm/test/CodeGen/AArch64/load-combine.ll
index 57f61e5303ecf9..b30ee45aa4d1a0 100644
--- a/llvm/test/CodeGen/AArch64/load-combine.ll
+++ b/llvm/test/CodeGen/AArch64/load-combine.ll
@@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
 ; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ldr s0, [x0]
-; CHECK-NEXT:    ldrh w9, [x0]
 ; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    umov w8, v0.h[2]
-; CHECK-NEXT:    orr w8, w9, w8, lsl #16
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    umov w9, v0.h[0]
+; CHECK-NEXT:    umov w10, v0.h[2]
+; CHECK-NEXT:    bfi w9, w8, #8, #8
+; CHECK-NEXT:    orr w8, w9, w10, lsl #16
 ; CHECK-NEXT:    str w8, [x1]
 ; CHECK-NEXT:    ret
   %ld = load <4 x i8>, ptr %in, align 4
diff --git a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
index c27e44609c527f..96921082801821 100644
--- a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
+++ b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
@@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
 ; GCN-LABEL: load_3xi16_combine:
 ; GCN:       ; %bb.0:
 ; GCN-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GCN-NEXT:    global_load_dword v2, v[0:1], off
-; GCN-NEXT:    global_load_ushort v3, v[0:1], off offset:4
+; GCN-NEXT:    global_load_dword v3, v[0:1], off
+; GCN-NEXT:    global_load_ushort v2, v[0:1], off offset:4
+; GCN-NEXT:    s_mov_b32 s4, 0xffff
 ; GCN-NEXT:    s_waitcnt vmcnt(1)
-; GCN-NEXT:    v_mov_b32_e32 v0, v2
+; GCN-NEXT:    v_and_b32_e32 v0, 0xffff0000, v3
+; GCN-NEXT:    v_and_or_b32 v0, v3, s4, v0
 ; GCN-NEXT:    s_waitcnt vmcnt(0)
-; GCN-NEXT:    v_mov_b32_e32 v1, v3
+; GCN-NEXT:    v_mov_b32_e32 v1, v2
 ; GCN-NEXT:    s_setpc_b64 s[30:31]
   %gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
   %gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2
diff --git a/llvm/test/CodeGen/X86/load-combine.ll b/llvm/test/CodeGen/X86/load-combine.ll
index 7e4e11fcc75c20..530e17a0b0f099 100644
--- a/llvm/test/CodeGen/X86/load-combine.ll
+++ b/llvm/test/CodeGen/X86/load-combine.ll
@@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
   ret i32 %tmp8
 }
 
-; FIXME: This is a miscompile.
 define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
 ; CHECK-LABEL: pr80911_vector_load_multiuse:
 ; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushl %edi
 ; CHECK-NEXT:    pushl %esi
-; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; CHECK-NEXT:    movl (%edx), %esi
-; CHECK-NEXT:    movzwl (%edx), %eax
-; CHECK-NEXT:    movl $0, (%ecx)
-; CHECK-NEXT:    movl %esi, (%edx)
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %esi
+; CHECK-NEXT:    movzbl (%esi), %ecx
+; CHECK-NEXT:    movzbl 1(%esi), %eax
+; CHECK-NEXT:    movzwl 2(%esi), %edi
+; CHECK-NEXT:    movl $0, (%edx)
+; CHECK-NEXT:    movw %di, 2(%esi)
+; CHECK-NEXT:    movb %al, 1(%esi)
+; CHECK-NEXT:    movb %cl, (%esi)
+; CHECK-NEXT:    shll $8, %eax
+; CHECK-NEXT:    orl %ecx, %eax
 ; CHECK-NEXT:    popl %esi
+; CHECK-NEXT:    popl %edi
 ; CHECK-NEXT:    retl
 ;
 ; CHECK64-LABEL: pr80911_vector_load_multiuse:
 ; CHECK64:       # %bb.0:
-; CHECK64-NEXT:    movzwl (%rdi), %eax
+; CHECK64-NEXT:    movaps (%rdi), %xmm0
 ; CHECK64-NEXT:    movl $0, (%rsi)
-; CHECK64-NEXT:    movl (%rdi), %ecx
-; CHECK64-NEXT:    movl %ecx, (%rdi)
+; CHECK64-NEXT:    movss %xmm0, (%rdi)
+; CHECK64-NEXT:    movaps %xmm0, -{{[0-9]+}}(%rsp)
+; CHECK64-NEXT:    movzwl -{{[0-9]+}}(%rsp), %eax
 ; CHECK64-NEXT:    retq
   %load = load <4 x i8>, ptr %ptr, align 16
   store i32 0, ptr %clobber

``````````

</details>


https://github.com/llvm/llvm-project/pull/81492


More information about the llvm-commits mailing list