[llvm] r289538 - [DAGCombiner] Match load by bytes idiom and fold it into a single load

Artur Pilipenko via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 16 04:32:43 PST 2016


Comments inlined…

On 16 Dec 2016, at 07:00, Chandler Carruth <chandlerc at gmail.com> wrote:

On Tue, Dec 13, 2016 at 6:31 AM Artur Pilipenko via llvm-commits <llvm-commits at lists.llvm.org> wrote:
Author: apilipenko
Date: Tue Dec 13 08:21:14 2016
New Revision: 289538

URL: http://llvm.org/viewvc/llvm-project?rev=289538&view=rev
Log:
[DAGCombiner] Match load by bytes idiom and fold it into a single load
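For context, this is the kind of source-level "load by bytes" idiom the combine targets, written as plain C++. The function names are illustrative only. When the target's endianness matches the pattern and the access is legal, each OR-of-shifts expression can be folded into a single 32-bit load.

```cpp
#include <cassert>
#include <cstdint>

// Little-endian idiom: byte p[i] lands in bits [8*i, 8*i + 8).
// The casts avoid shifting a promoted int into its sign bit.
uint32_t loadLittleEndianByBytes(const uint8_t *p) {
  return static_cast<uint32_t>(p[0]) |
         (static_cast<uint32_t>(p[1]) << 8) |
         (static_cast<uint32_t>(p[2]) << 16) |
         (static_cast<uint32_t>(p[3]) << 24);
}

// Big-endian idiom, as in the ARM big-endian tests:
// ((i32) p[0] << 24) | ((i32) p[1] << 16) | ((i32) p[2] << 8) | (i32) p[3]
uint32_t loadBigEndianByBytes(const uint8_t *p) {
  return (static_cast<uint32_t>(p[0]) << 24) |
         (static_cast<uint32_t>(p[1]) << 16) |
         (static_cast<uint32_t>(p[2]) << 8) |
         static_cast<uint32_t>(p[3]);
}
```

Both functions give the same result on any host; only the single-load replacement depends on the target's byte order.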

Really cool transform, but unfortunately, this patch has a bad bug in it that is a bit tricky to trigger. Consider the test case:

target triple = "x86_64-unknown-linux-gnu"

define void @crash(i8* %src1, i8* %src2, i64* %dst) {
entry:
  %load1 = load i8, i8* %src1, align 1
  %conv46 = zext i8 %load1 to i32
  %shl47 = shl i32 %conv46, 56
  %or55 = or i32 %shl47, 0
  %load2 = load i8, i8* %src2, align 1
  %conv57 = zext i8 %load2 to i32
  %shl58 = shl i32 %conv57, 32
  %or59 = or i32 %or55, %shl58
  %or74 = or i32 %or59, 0
  %conv75 = sext i32 %or74 to i64
  store i64 %conv75, i64* %dst, align 8
  ret void
}

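This will crash when it tries to move an end iterator back past the beginning of a vector: shl i32 %conv46, 56 is a 7-byte shift of a 4-byte value, so std::prev(Original->end(), ByteShift) steps before begin(). You need to check for bad shift widths.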
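A minimal sketch of the missing check. It uses a simplified, hypothetical Provider type and std::optional in place of the patch's ByteProvider and llvm::Optional, and mirrors the patch's SHL case. The difference is that it bails out when the byte shift is not smaller than the value's width. Bailing out is the conservative choice, since such shifts have an undefined result anyway.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified, hypothetical stand-in for the patch's ByteProvider.
struct Provider {
  enum Kind { Unknown, Zero, Memory } K;
  unsigned ByteOffset; // only meaningful for Memory
};

// Providers for (Original << BitShift); index 0 is the least significant byte.
std::optional<std::vector<Provider>>
shlProviders(const std::vector<Provider> &Original, uint64_t BitShift) {
  if (BitShift % 8 != 0)
    return std::nullopt;
  uint64_t ByteShift = BitShift / 8;
  // The missing check: for shl i32 %x, 56 the byte shift is 7 on a 4-byte
  // value, and std::prev(Original->end(), ByteShift) would step before begin().
  if (ByteShift >= Original.size())
    return std::nullopt;

  // Low bytes are shifted-in zeros. The rest are the surviving low bytes of
  // the original value, so vector index and ByteOffset now differ by ByteShift.
  std::vector<Provider> Result(ByteShift, Provider{Provider::Zero, 0});
  Result.insert(Result.end(), Original.begin(),
                Original.end() - static_cast<std::ptrdiff_t>(ByteShift));
  return Result;
}
```

With this guard, a 16-bit shift of a 4-byte load yields {zero, zero, offset 0, offset 1}. Shifts of 32 or 56 bits return std::nullopt instead of walking an iterator before begin().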


However, there are several other serious problems with the patch that need to be addressed so I'm going to revert it for now to unblock folks (we're hitting this in real code) and let you post an updated patch for review. Please get the updated patch re-reviewed, as I think this needs some more detailed review around data structures and algorithms used. See my (very rough) comments below...

==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Tue Dec 13 08:21:14 2016
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -375,6 +376,7 @@ namespace {
                               unsigned PosOpcode, unsigned NegOpcode,
                               const SDLoc &DL);
     SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+    SDValue MatchLoadCombine(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3969,6 +3971,9 @@ SDValue DAGCombiner::visitOR(SDNode *N)
   if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
     return SDValue(Rot, 0);

+  if (SDValue Load = MatchLoadCombine(N))
+    return Load;
+
   // Simplify the operands using demanded-bits information.
   if (!VT.isVector() &&
       SimplifyDemandedBits(SDValue(N, 0)))
@@ -4340,6 +4345,277 @@ struct BaseIndexOffset {
 };
 } // namespace

+namespace {
+/// Represents the origin of an individual byte in load combine pattern. The
+/// value of the byte is either unknown, zero or comes from memory.
+struct ByteProvider {
+  enum ProviderTy {
+    Unknown,
+    ZeroConstant,
+    Memory
+  };
+
+  ProviderTy Kind;

Putting this before a pointer means that almost all of the bits here are padding. =[
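Here is a standalone layout comparison that makes the padding concrete. The opaque LoadNode type is a hypothetical stand-in for LoadSDNode. On a typical LP64 target (8-byte pointers, 4-byte unsigned and enum), the field order in the patch takes 24 bytes, while placing the two 4-byte fields next to each other takes 16. Exact sizes are ABI-dependent.

```cpp
#include <cassert>

struct LoadNode; // hypothetical opaque stand-in for LoadSDNode

// Field order as in the patch: 4-byte enum, then pointer, then unsigned.
// On LP64: 4 + 4 (padding) + 8 + 4 + 4 (tail padding) = 24 bytes.
struct ByteProviderAsWritten {
  enum ProviderTy { Unknown, ZeroConstant, Memory } Kind;
  LoadNode *Load;
  unsigned ByteOffset;
};

// Same fields, pointer first and the two 4-byte members adjacent.
// On LP64: 8 + 4 + 4 = 16 bytes, with no padding.
struct ByteProviderReordered {
  LoadNode *Load;
  unsigned ByteOffset;
  ByteProviderAsWritten::ProviderTy Kind;
};
```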

+  // Load and ByteOffset are set for Memory providers only.
+  // Load represents the node which loads the byte from memory.
+  // ByteOffset is the offset of the byte in the value produced by the load.
+  LoadSDNode *Load;
+  unsigned ByteOffset;

Most of the providers are not loads though? This data structure seems really inefficient for the common case.

Also, ByteOffset seems completely redundant. It is always set to the same value as the index into the vector below. Can you eliminate it or avoid it somehow?
ByteOffset is initially set to the same value as the index into the vector, but shifts can then shuffle the vector. For example, the chain:
shl (load i32* %p), 16
produces providers:
{ zero, zero, load p[0], p[1] }
so ByteOffsets 0 and 1 end up at vector indices 2 and 3.

In an upcoming change I’m going to handle bswap as part of the combine pattern. It will also shuffle the offsets.

Given that ByteOffset is needed, I don’t immediately see how to compress the structure.
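The bswap point can be shown with a toy model. The MemByte type is hypothetical and is not the patch's code. bswap moves the provider at index i to index N - 1 - i, so afterwards the stored ByteOffset can no longer be recomputed from the vector index.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical simplified provider: only the offset of the byte within the
// value produced by some load.
struct MemByte {
  unsigned ByteOffset;
};

// bswap reverses the byte order of a value: the provider at index i moves
// to index (N - 1 - i), so index and ByteOffset stop matching.
std::vector<MemByte> bswapProviders(std::vector<MemByte> Bytes) {
  std::reverse(Bytes.begin(), Bytes.end());
  return Bytes;
}
```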

Then, could you use a PointerIntPair so that this whole thing becomes the size of a pointer? That would seem much more efficient.
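Even with ByteOffset, the provider could be pointer-sized under two explicit assumptions. Load nodes must be at least 8-byte aligned, and no load may be wider than 8 bytes, so the offset fits in the three free low bits of the pointer. Unknown and ZeroConstant are then encoded as a null pointer with tag 0 or 1, so the kind needs no field of its own. FakeLoad is a hypothetical stand-in for LoadSDNode. In LLVM the packing would more naturally go through llvm::PointerIntPair; it is hand-rolled here only to keep the example self-contained.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for LoadSDNode; alignas(8) leaves three low
// pointer bits free.
struct alignas(8) FakeLoad {
  int Dummy;
};

// Pointer-sized provider (hypothetical encoding, not LLVM's):
//   null pointer, tag 0  -> Unknown
//   null pointer, tag 1  -> ZeroConstant
//   non-null pointer     -> Memory, tag = ByteOffset (0..7)
class PackedByteProvider {
  static constexpr std::uintptr_t TagMask = 7;
  std::uintptr_t Bits;
  explicit PackedByteProvider(std::uintptr_t B) : Bits(B) {}

public:
  static PackedByteProvider getUnknown() { return PackedByteProvider(0); }
  static PackedByteProvider getZero() { return PackedByteProvider(1); }
  static PackedByteProvider getMemory(const FakeLoad *L, unsigned ByteOffset) {
    auto P = reinterpret_cast<std::uintptr_t>(L);
    assert(L && (P & TagMask) == 0 && "load must be non-null and 8-aligned");
    assert(ByteOffset <= TagMask && "offset must fit in the low bits");
    return PackedByteProvider(P | ByteOffset);
  }

  const FakeLoad *getLoad() const {
    return reinterpret_cast<const FakeLoad *>(Bits & ~TagMask);
  }
  bool isUnknown() const { return Bits == 0; }
  bool isZero() const { return Bits == 1; }
  bool isMemory() const { return getLoad() != nullptr; }
  unsigned getByteOffset() const {
    assert(isMemory() && "only memory providers have an offset");
    return static_cast<unsigned>(Bits & TagMask);
  }
  bool operator==(PackedByteProvider O) const { return Bits == O.Bits; }
};

static_assert(sizeof(PackedByteProvider) == sizeof(void *),
              "the whole provider is one pointer wide");
```

The trade-off is the hard cap of 8 bytes per load. Wider loads would need a fourth free bit, meaning 16-byte alignment.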

+/// Recursively traverses the expression collecting the origin of individual
+/// bytes of the given value. For all the values except the root of the
+/// expression verifies that it doesn't have uses outside of the expression.
+const Optional<SmallVector<ByteProvider, 4> >

Please use clang-format with the latest LLVM settings. We don't need the space before the closing '>' any more.

+collectByteProviders(SDValue Op, bool CheckNumberOfUses = false) {
+  if (CheckNumberOfUses && !Op.hasOneUse())
+    return None;
+
+  unsigned BitWidth = Op.getScalarValueSizeInBits();
+  if (BitWidth % 8 != 0)
+    return None;
+  unsigned ByteWidth = BitWidth / 8;

Do you want to bound the size you consider here? Especially since this is in the backend, you should know the maximum size you can fuse together... But maybe if you make this more efficient you can handle the full width of any integer.
I bound the type of the root of the chain in DAGCombiner::MatchLoadCombine, but I allow intermediate parts of the chain to be any width.

+
+  switch (Op.getOpcode()) {
+  case ISD::OR: {
+    auto LHS = collectByteProviders(Op->getOperand(0),
+                                    /*CheckNumberOfUses=*/true);
+    auto RHS = collectByteProviders(Op->getOperand(1),
+                                    /*CheckNumberOfUses=*/true);
+    if (!LHS || !RHS)
+      return None;
+
+    auto OR = [](ByteProvider LHS, ByteProvider RHS) {
+      if (LHS == RHS)
+        return LHS;
+      if (LHS.Kind == ByteProvider::Unknown ||
+          RHS.Kind == ByteProvider::Unknown)
+        return ByteProvider::getUnknown();
+      if (LHS.Kind == ByteProvider::Memory && RHS.Kind == ByteProvider::Memory)
+        return ByteProvider::getUnknown();
+
+      if (LHS.Kind == ByteProvider::Memory)
+        return LHS;
+      else
+        return RHS;
+    };
+
+    SmallVector<ByteProvider, 4> Result(ByteWidth);
+    for (unsigned i = 0; i < LHS->size(); i++)
+      Result[i] = OR(LHS.getValue()[i], RHS.getValue()[i]);

This is really inefficient.

You build the entire tree for each layer even if you throw it away again.
Do you mean cases where some of the bytes turn into unknown? Once we have the providers for one operand, we might know that specific bytes will become unknown, but we still need to collect providers for the other bytes. Currently collectByteProviders does this for all bytes in the value. One option is to pass collectByteProviders a mask specifying which bytes are actually needed.

And because you recurse upwards without any bound, it is quadratic for tall chains.

For combines that recurse upward to not be quadratic in nature, they must:
- Combine away the nodes they recurse upward through so we don't re-visit them, or
- Have a constant factor limit on the depth of recursion
collectByteProviders doesn’t follow nodes that have more than one use (the first check in this function). I don’t want to combine patterns whose individual parts are used elsewhere. A side effect is that collectByteProviders never visits the same node more than once.

And to make matters worse, when you have a load of more than one byte, you build the information for the load redundantly, once for each byte.
Can you explain? collectByteProviders collects providers for all the bytes in the value. To handle OR, I collect providers for all bytes on both sides of the operation and then analyze the result.
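The approach described here can be sketched standalone: collect providers for both operands, then combine them byte by byte. Provider is a simplified, hypothetical type with an integer load id in place of the LoadSDNode pointer. orByte follows the same rules as the OR lambda in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified, hypothetical provider: Load is an integer id standing in for
// the LoadSDNode pointer and is only meaningful for Memory.
struct Provider {
  enum Kind { Unknown, Zero, Memory } K;
  int Load;
  unsigned ByteOffset;
  bool operator==(const Provider &O) const {
    return K == O.K && Load == O.Load && ByteOffset == O.ByteOffset;
  }
};

// Per-byte merge with the same rules as the patch's OR lambda.
Provider orByte(Provider L, Provider R) {
  if (L == R)
    return L; // x | x == x
  if (L.K == Provider::Unknown || R.K == Provider::Unknown)
    return {Provider::Unknown, 0, 0};
  if (L.K == Provider::Memory && R.K == Provider::Memory)
    return {Provider::Unknown, 0, 0}; // two different memory bytes overlap
  return L.K == Provider::Memory ? L : R; // memory | zero -> memory
}

// Providers for (LHS | RHS); both operand vectors are built in full first.
std::vector<Provider> orProviders(const std::vector<Provider> &LHS,
                                  const std::vector<Provider> &RHS) {
  assert(LHS.size() == RHS.size() && "OR operands have the same width");
  std::vector<Provider> Result(LHS.size());
  for (std::size_t I = 0; I < LHS.size(); ++I)
    Result[I] = orByte(LHS[I], RHS[I]);
  return Result;
}
```

Both operand vectors exist in full before orProviders runs. Every byte of both sides is therefore computed, even when the merge later turns it into Unknown.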

+
+    return Result;
+  }
+  case ISD::SHL: {
+    auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    if (!ShiftOp)
+      return None;
+
+    uint64_t BitShift = ShiftOp->getZExtValue();
+    if (BitShift % 8 != 0)
+      return None;
+    uint64_t ByteShift = BitShift / 8;
+
+    auto Original = collectByteProviders(Op->getOperand(0),
+                                         /*CheckNumberOfUses=*/true);
+    if (!Original)
+      return None;
+
+    SmallVector<ByteProvider, 4> Result;
+    Result.insert(Result.begin(), ByteShift, ByteProvider::getZero());
+    Result.insert(Result.end(), Original->begin(),
+                  std::prev(Original->end(), ByteShift));
+    assert(Result.size() == ByteWidth && "sanity");
+    return Result;

This is again really inefficient.

First, you build the full Original set of things, but then throw away the tail. And second you are copying things from small vector to small vector everywhere here.
The same idea with the mask of required bytes might work here.

This really feels like it should be linearly filling in the provider for each of the bytes that survive to be used, with no copying or other movement. That way you only recurse through the byte providers you actually care about, etc. Does that make sense?
The tradeoff of this approach is that we need to scan the whole tree for each individual byte.

Or maybe some other algorithm. But fundamentally, I think this recursive walk will need to be significantly re-worked.
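Here is a rough sketch of the per-byte alternative over a toy expression tree, combined with the constant depth limit from the list above. Expr, ByteSource, byteProvider, and the limit of 10 are all hypothetical; this is one possible shape, not the patch's code. Each query recurses only into the operands the requested byte depends on. Per-byte work is therefore bounded by a constant that depends only on the depth limit, at the cost of re-walking the tree once per byte.

```cpp
#include <cassert>
#include <optional>

// Toy expression node (hypothetical, not SelectionDAG). Widths are in bytes.
struct Expr {
  enum Op { Load, Zero, ZExt, Or, Shl } Opc;
  unsigned ByteWidth;
  int LoadId;     // Load only: which load this is
  unsigned Shift; // Shl only: shift amount in bytes
  const Expr *A;  // ZExt/Shl/Or operand
  const Expr *B;  // Or second operand
};

// Where one byte comes from: known zero, or byte Offset of load LoadId.
struct ByteSource {
  bool IsZero;
  int LoadId;
  unsigned Offset;
};

// Answers "where does byte Index of E come from?" without materializing
// providers for the other bytes. MaxDepth bounds each query.
std::optional<ByteSource> byteProvider(const Expr &E, unsigned Index,
                                       unsigned Depth = 0) {
  constexpr unsigned MaxDepth = 10;
  if (Depth == MaxDepth || Index >= E.ByteWidth)
    return std::nullopt;
  switch (E.Opc) {
  case Expr::Load:
    return ByteSource{false, E.LoadId, Index};
  case Expr::Zero:
    return ByteSource{true, 0, 0};
  case Expr::ZExt:
    if (Index >= E.A->ByteWidth)
      return ByteSource{true, 0, 0}; // extended high bytes are zero
    return byteProvider(*E.A, Index, Depth + 1);
  case Expr::Shl:
    if (E.Shift >= E.ByteWidth)
      return std::nullopt; // over-wide shift: give up instead of crashing
    if (Index < E.Shift)
      return ByteSource{true, 0, 0}; // shifted-in low bytes are zero
    return byteProvider(*E.A, Index - E.Shift, Depth + 1);
  case Expr::Or: {
    std::optional<ByteSource> L = byteProvider(*E.A, Index, Depth + 1);
    if (!L)
      return std::nullopt;
    std::optional<ByteSource> R = byteProvider(*E.B, Index, Depth + 1);
    if (!R)
      return std::nullopt;
    if (L->IsZero)
      return R; // 0 | x == x
    if (R->IsZero)
      return L;
    if (L->LoadId == R->LoadId && L->Offset == R->Offset)
      return L; // x | x == x
    return std::nullopt; // two different memory bytes overlap
  }
  }
  return std::nullopt;
}
```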


+  // Calculate byte providers for the OR we are looking at
+  auto Res = collectByteProviders(SDValue(N, 0));
+  if (!Res)
+    return SDValue();
+  auto &Bytes = Res.getValue();
+  unsigned ByteWidth = Bytes.size();
+  assert(VT.getSizeInBits() == ByteWidth * 8 && "sanity");
+
+  auto LittleEndianByteAt = [](unsigned BW, unsigned i) { return i; };
+  auto BigEndianByteAt = [](unsigned BW, unsigned i) { return BW - i - 1; };
+
+  Optional<BaseIndexOffset> Base;
+  SDValue Chain;
+
+  SmallSet<LoadSDNode *, 8> Loads;
+  LoadSDNode *FirstLoad = nullptr;
+
+  // Check if all the bytes of the OR we are looking at are loaded from the same
+  // base address. Collect bytes offsets from Base address in ByteOffsets.
+  SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+  for (unsigned i = 0; i < ByteWidth; i++) {
+    // All the bytes must be loaded from memory
+    if (Bytes[i].Kind != ByteProvider::Memory)
+      return SDValue();

This again seems especially wasteful. Above you carefully built a structure tracking which bytes were zeroed and which came from memory, but you don't need all of that. You just need the bytes from memory and you could early exit if they aren't from memory.
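Here is a sketch of the early-exit check, assuming a hypothetical per-byte MemByteAddr (base id plus address offset). The optional is empty for bytes not loaded from memory. The function bails out at the first such byte, then checks whether the offsets are consecutive in little- or big-endian order, in the spirit of the patch's LittleEndianByteAt and BigEndianByteAt lambdas.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical per-byte address: the base pointer the byte was loaded from
// (as an id) and the byte's offset from that base.
struct MemByteAddr {
  int Base;
  int64_t Offset;
};

enum class Order { Little, Big };

// Decides whether the bytes of a value (index 0 = least significant) are
// consecutive in memory, and in which byte order.
std::optional<Order>
matchConsecutiveBytes(const std::vector<std::optional<MemByteAddr>> &Bytes) {
  auto LittleEndianByteAt = [](unsigned, unsigned I) { return I; };
  auto BigEndianByteAt = [](unsigned BW, unsigned I) { return BW - I - 1; };

  if (Bytes.empty())
    return std::nullopt;
  int64_t Min = std::numeric_limits<int64_t>::max();
  for (const std::optional<MemByteAddr> &B : Bytes) {
    // Early exit: the byte is not from memory, or it uses another base.
    if (!B || B->Base != Bytes[0]->Base)
      return std::nullopt;
    Min = std::min(Min, B->Offset);
  }

  const unsigned BW = static_cast<unsigned>(Bytes.size());
  bool Little = true, Big = true;
  for (unsigned I = 0; I < BW; ++I) {
    int64_t Rel = Bytes[I]->Offset - Min; // position relative to lowest byte
    Little = Little && Rel == static_cast<int64_t>(LittleEndianByteAt(BW, I));
    Big = Big && Rel == static_cast<int64_t>(BigEndianByteAt(BW, I));
  }
  if (Little)
    return Order::Little;
  if (Big)
    return Order::Big;
  return std::nullopt;
}
```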

Alternatively, you could implement something that does the full load and mask off the things that need to be zero. But that would be a more complex cost model and tradeoff space so it might not be worth it.
I’m going to use the zero-byte information to handle partially available values. Currently the transform only handles patterns where all the bytes are known to be loaded from memory, but I’d like to support cases where we know that the higher bytes are zero and the lower bytes are read from memory.

Artur

==============================================================================
--- llvm/trunk/test/CodeGen/ARM/load-combine-big-endian.ll (added)
+++ llvm/trunk/test/CodeGen/ARM/load-combine-big-endian.ll Tue Dec 13 08:21:14 2016
@@ -0,0 +1,234 @@
+; RUN: llc < %s -mtriple=armeb-unknown | FileCheck %s
+; RUN: llc < %s -mtriple=arm64eb-unknown | FileCheck %s --check-prefix=CHECK64
+
+; i8* p; // p is 4 byte aligned
+; ((i32) p[0] << 24) | ((i32) p[1] << 16) | ((i32) p[2] << 8) | (i32) p[3]
+define i32 @load_i32_by_i8_big_endian(i32*) {
+; CHECK-LABEL: load_i32_by_i8_big_endian:
+; CHECK: ldr r0, [r0]
+; CHECK-NEXT: mov pc, lr
+
+; CHECK64-LABEL: load_i32_by_i8_big_endian:
+; CHECK64: ldr         w0, [x0]
+; CHECK64-NEXT: ret
+  %2 = bitcast i32* %0 to i8*
+  %3 = load i8, i8* %2, align 4

Please *always* use named values in your tests. Otherwise, making minor updates is incredibly painful, as you have to renumber all subsequent values.
