[llvm] 7972a6e - [DAGCombiner][NFC] Factor out ByteProvider

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 19 08:54:54 PDT 2023


Author: Jeffrey Byrnes
Date: 2023-06-19T08:54:34-07:00
New Revision: 7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2

URL: https://github.com/llvm/llvm-project/commit/7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2
DIFF: https://github.com/llvm/llvm-project/commit/7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2.diff

LOG: [DAGCombiner][NFC] Factor out ByteProvider

Differential Revision: https://reviews.llvm.org/D143018

Change-Id: I3dc03787a3382c0c3fe6b869f869c2946f450874

Added: 
    llvm/include/llvm/CodeGen/ByteProvider.h

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h
new file mode 100644
index 0000000000000..e0ba40b135336
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ByteProvider.h
@@ -0,0 +1,90 @@
+//===-- include/llvm/CodeGen/ByteProvider.h - Map bytes ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file implements ByteProvider. The purpose of ByteProvider is to provide
+// a map between a target node's byte (byte position is DestOffset) and the
+// source (and byte position) that provides it (in Src and SrcOffset
+// respectively). See CodeGen/SelectionDAG/DAGCombiner.cpp MatchLoadCombine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_BYTEPROVIDER_H
+#define LLVM_CODEGEN_BYTEPROVIDER_H
+
+#include <optional>
+#include <type_traits>
+
+namespace llvm {
+
+/// Represents the known origin of an individual byte in a combine pattern. The
+/// value of the byte is either constant zero, or comes from memory /
+/// some other productive instruction (e.g. arithmetic instructions).
+/// Bit manipulation instructions like shifts are not ByteProviders, rather
+/// they are used to extract bytes.
+template <typename ISelOp> class ByteProvider {
+private:
+  ByteProvider<ISelOp>(std::optional<ISelOp> Src, int64_t DestOffset,
+                       int64_t SrcOffset)
+      : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
+
+  // TODO: Replace this trait with a constraint once C++20 is available.
+  // True when T (or its pointee type) has a getOpcode() member function.
+  template <typename T> class is_op {
+  private:
+    using yes = std::true_type;
+    using no = std::false_type;
+
+    // Selected only when U has a callable getOpcode() member (SFINAE).
+    template <typename U>
+    static auto test(int) -> decltype(std::declval<U>().getOpcode(), yes());
+
+    template <typename> static no test(...);
+
+  public:
+    using remove_pointer_t = typename std::remove_pointer<T>::type;
+    static constexpr bool value =
+        std::is_same<decltype(test<remove_pointer_t>(0)), yes>::value;
+  };
+
+public:
+  // For constant-zero providers, Src is std::nullopt. For actual providers,
+  // Src is the node which originally produced the relevant bits.
+  std::optional<ISelOp> Src = std::nullopt;
+  // DestOffset: byte offset in the dest being mapped (getSrc's ByteOffset).
+  int64_t DestOffset = 0;
+  // SrcOffset: offset in the ultimate source node that maps to DestOffset
+  // (getSrc's VectorOffset).
+  int64_t SrcOffset = 0;
+
+  ByteProvider() = default;
+
+  static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
+                             int64_t VectorOffset) {
+    static_assert(is_op<ISelOp>().value,
+                  "ByteProviders must contain an operation in selection DAG.");
+    return ByteProvider(Val, ByteOffset, VectorOffset);
+  }
+
+  static ByteProvider getConstantZero() {
+    return ByteProvider<ISelOp>(std::nullopt, 0, 0);
+  }
+  bool isConstantZero() const { return !Src; }
+
+  bool hasSrc() const { return Src.has_value(); }
+
+  bool hasSameSrc(const ByteProvider &Other) const { return Other.Src == Src; }
+
+  bool operator==(const ByteProvider &Other) const {
+    return Other.Src == Src && Other.DestOffset == DestOffset &&
+           Other.SrcOffset == SrcOffset;
+  }
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_BYTEPROVIDER_H

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 454944419293c..56dc284adf709 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ByteProvider.h"
 #include "llvm/CodeGen/DAGCombine.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -8402,42 +8403,6 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
   return SDValue();
 }
 
-namespace {
-
-/// Represents known origin of an individual byte in load combine pattern. The
-/// value of the byte is either constant zero or comes from memory.
-struct ByteProvider {
-  // For constant zero providers Load is set to nullptr. For memory providers
-  // Load represents the node which loads the byte from memory.
-  // ByteOffset is the offset of the byte in the value produced by the load.
-  LoadSDNode *Load = nullptr;
-  unsigned ByteOffset = 0;
-  unsigned VectorOffset = 0;
-
-  ByteProvider() = default;
-
-  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset,
-                                unsigned VectorOffset) {
-    return ByteProvider(Load, ByteOffset, VectorOffset);
-  }
-
-  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); }
-
-  bool isConstantZero() const { return !Load; }
-  bool isMemory() const { return Load; }
-
-  bool operator==(const ByteProvider &Other) const {
-    return Other.Load == Load && Other.ByteOffset == ByteOffset &&
-           Other.VectorOffset == VectorOffset;
-  }
-
-private:
-  ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset)
-      : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {}
-};
-
-} // end anonymous namespace
-
 /// Recursively traverses the expression calculating the origin of the requested
 /// byte of the given value. Returns std::nullopt if the provider can't be
 /// calculated.
@@ -8479,7 +8444,9 @@ struct ByteProvider {
 ///                 LOAD
 ///
 /// *ExtractVectorElement
-static const std::optional<ByteProvider>
+using SDByteProvider = ByteProvider<SDNode *>;
+
+static const std::optional<SDByteProvider>
 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                       std::optional<uint64_t> VectorIndex,
                       unsigned StartingIndex = 0) {
@@ -8538,7 +8505,7 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     // provide, then do not provide anything. Otherwise, subtract the index by
     // the amount we shifted by.
     return Index < ByteShift
-               ? ByteProvider::getConstantZero()
+               ? SDByteProvider::getConstantZero()
                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
                                        Depth + 1, VectorIndex, Index);
   }
@@ -8553,7 +8520,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
 
     if (Index >= NarrowByteWidth)
       return Op.getOpcode() == ISD::ZERO_EXTEND
-                 ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
+                 ? std::optional<SDByteProvider>(
+                       SDByteProvider::getConstantZero())
                  : std::nullopt;
     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
                                  StartingIndex);
@@ -8603,11 +8571,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     // question
     if (Index >= NarrowByteWidth)
       return L->getExtensionType() == ISD::ZEXTLOAD
-                 ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
+                 ? std::optional<SDByteProvider>(
+                       SDByteProvider::getConstantZero())
                  : std::nullopt;
 
     unsigned BPVectorIndex = VectorIndex.value_or(0U);
-    return ByteProvider::getMemory(L, Index, BPVectorIndex);
+    return SDByteProvider::getSrc(L, Index, BPVectorIndex);
   }
   }
 
@@ -8901,23 +8870,24 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   unsigned ByteWidth = VT.getSizeInBits() / 8;
 
   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
-  auto MemoryByteOffset = [&] (ByteProvider P) {
-    assert(P.isMemory() && "Must be a memory byte provider");
-    unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits();
+  auto MemoryByteOffset = [&](SDByteProvider P) {
+    assert(P.hasSrc() && "Must be a memory byte provider");
+    auto *Load = cast<LoadSDNode>(P.Src.value());
+
+    unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
 
     assert(LoadBitWidth % 8 == 0 &&
            "can only analyze providers for individual bytes not bit");
     unsigned LoadByteWidth = LoadBitWidth / 8;
-    return IsBigEndianTarget
-            ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
-            : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
+    return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
+                             : littleEndianByteAt(LoadByteWidth, P.DestOffset);
   };
 
   std::optional<BaseIndexOffset> Base;
   SDValue Chain;
 
   SmallPtrSet<LoadSDNode *, 8> Loads;
-  std::optional<ByteProvider> FirstByteProvider;
+  std::optional<SDByteProvider> FirstByteProvider;
   int64_t FirstOffset = INT64_MAX;
 
   // Check if all the bytes of the OR we are looking at are loaded from the same
@@ -8938,9 +8908,8 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
         return SDValue();
       continue;
     }
-    assert(P->isMemory() && "provenance should either be memory or zero");
-
-    LoadSDNode *L = P->Load;
+    assert(P->hasSrc() && "provenance should either be memory or zero");
+    auto *L = cast<LoadSDNode>(P->Src.value());
 
     // All loads must share the same chain
     SDValue LChain = L->getChain();
@@ -8964,7 +8933,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
       unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
       if (LoadWidthInBit % 8 != 0)
         return SDValue();
-      unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8;
+      unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
       Ptr.addToOffset(ByteOffsetFromVector);
     }
 
@@ -9021,7 +8990,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   // So the combined value can be loaded from the first load address.
   if (MemoryByteOffset(*FirstByteProvider) != 0)
     return SDValue();
-  LoadSDNode *FirstLoad = FirstByteProvider->Load;
+  auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
 
   // The node we are looking at matches with the pattern, check if we can
   // replace it with a single (possibly zero-extended) load and bswap + shift if


        


More information about the llvm-commits mailing list