[llvm] 7972a6e - [DAGCombiner][NFC] Factor out ByteProvider
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 19 08:54:54 PDT 2023
Author: Jeffrey Byrnes
Date: 2023-06-19T08:54:34-07:00
New Revision: 7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2
URL: https://github.com/llvm/llvm-project/commit/7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2
DIFF: https://github.com/llvm/llvm-project/commit/7972a6e126bf3af19d9871e81c6ce1b4cb2fdea2.diff
LOG: [DAGCombiner][NFC] Factor out ByteProvider
Differential Revision: https://reviews.llvm.org/D143018
Change-Id: I3dc03787a3382c0c3fe6b869f869c2946f450874
Added:
llvm/include/llvm/CodeGen/ByteProvider.h
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h
new file mode 100644
index 0000000000000..e0ba40b135336
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ByteProvider.h
@@ -0,0 +1,90 @@
+//===-- include/llvm/CodeGen/ByteProvider.h - Map bytes ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file implements ByteProvider. The purpose of ByteProvider is to provide
+// a map between a target node's byte (byte position is DestOffset) and the
+// source (and byte position) that provides it (in Src and SrcOffset
+// respectively). See CodeGen/SelectionDAG/DAGCombiner.cpp MatchLoadCombine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_BYTEPROVIDER_H
+#define LLVM_CODEGEN_BYTEPROVIDER_H
+
+#include <optional>
+#include <type_traits>
+
+namespace llvm {
+
+/// Represents the known origin of an individual byte in a combine pattern. The
+/// value of the byte is either constant zero, or comes from memory /
+/// some other productive instruction (e.g. arithmetic instructions).
+/// Bit manipulation instructions like shifts are not ByteProviders; rather,
+/// they are used to extract bytes.
+template <typename ISelOp> class ByteProvider {
+private:
+ ByteProvider<ISelOp>(std::optional<ISelOp> Src, int64_t DestOffset,
+ int64_t SrcOffset)
+ : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
+ // NOTE(review): ctor above is named with a template-id (C++20 rejects it).
+ // TODO: replace the is_op trait below with a C++20 constraint.
+ // Detects whether T corresponds to an operation in the selection DAG.
+ template <typename T> class is_op {
+ private:
+ using yes = std::true_type;
+ using no = std::false_type;
+
+ // Selected only when the (pointer-stripped) type has a getOpcode() member.
+ template <typename U>
+ static auto test(int) -> decltype(std::declval<U>().getOpcode(), yes());
+ // Fallback overload, chosen when the getOpcode() expression is ill-formed.
+ template <typename> static no test(...);
+
+ public:
+ using remove_pointer_t = typename std::remove_pointer<T>::type;
+ static constexpr bool value =
+ std::is_same<decltype(test<remove_pointer_t>(0)), yes>::value;
+ };
+
+public:
+ // For constant zero providers Src is set to nullopt. For actual providers
+ // Src represents the node which originally produced the relevant bits.
+ std::optional<ISelOp> Src = std::nullopt;
+ // DestOffset is the offset of the byte in the destination being mapped.
+ int64_t DestOffset = 0;
+ // SrcOffset is the offset in the ultimate source node that maps to the
+ // DestOffset.
+ int64_t SrcOffset = 0;
+ // Default state: Src is nullopt (constant zero) and both offsets are 0.
+ ByteProvider() = default;
+ // Stores ByteOffset as DestOffset and VectorOffset as SrcOffset.
+ static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
+ int64_t VectorOffset) {
+ static_assert(is_op<ISelOp>().value,
+ "ByteProviders must contain an operation in selection DAG.");
+ return ByteProvider(Val, ByteOffset, VectorOffset);
+ }
+ // Returns a provider with no Src and both offsets set to 0.
+ static ByteProvider getConstantZero() {
+ return ByteProvider<ISelOp>(std::nullopt, 0, 0);
+ }
+ bool isConstantZero() const { return !Src; } // True when Src is empty.
+ // True when Src holds a value.
+ bool hasSrc() const { return Src.has_value(); }
+ // True when both providers have an equal Src; the offsets are ignored.
+ bool hasSameSrc(const ByteProvider &Other) const { return Other.Src == Src; }
+ // Memberwise equality over Src, DestOffset and SrcOffset.
+ bool operator==(const ByteProvider &Other) const {
+ return Other.Src == Src && Other.DestOffset == DestOffset &&
+ Other.SrcOffset == SrcOffset;
+ }
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_BYTEPROVIDER_H
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 454944419293c..56dc284adf709 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -32,6 +32,7 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ByteProvider.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
@@ -8402,42 +8403,6 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
return SDValue();
}
-namespace {
-
-/// Represents known origin of an individual byte in load combine pattern. The
-/// value of the byte is either constant zero or comes from memory.
-struct ByteProvider {
- // For constant zero providers Load is set to nullptr. For memory providers
- // Load represents the node which loads the byte from memory.
- // ByteOffset is the offset of the byte in the value produced by the load.
- LoadSDNode *Load = nullptr;
- unsigned ByteOffset = 0;
- unsigned VectorOffset = 0;
-
- ByteProvider() = default;
-
- static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset,
- unsigned VectorOffset) {
- return ByteProvider(Load, ByteOffset, VectorOffset);
- }
-
- static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); }
-
- bool isConstantZero() const { return !Load; }
- bool isMemory() const { return Load; }
-
- bool operator==(const ByteProvider &Other) const {
- return Other.Load == Load && Other.ByteOffset == ByteOffset &&
- Other.VectorOffset == VectorOffset;
- }
-
-private:
- ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset)
- : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {}
-};
-
-} // end anonymous namespace
-
/// Recursively traverses the expression calculating the origin of the requested
/// byte of the given value. Returns std::nullopt if the provider can't be
/// calculated.
@@ -8479,7 +8444,9 @@ struct ByteProvider {
/// LOAD
///
/// *ExtractVectorElement
-static const std::optional<ByteProvider>
+using SDByteProvider = ByteProvider<SDNode *>;
+
+static const std::optional<SDByteProvider>
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
std::optional<uint64_t> VectorIndex,
unsigned StartingIndex = 0) {
@@ -8538,7 +8505,7 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
// provide, then do not provide anything. Otherwise, subtract the index by
// the amount we shifted by.
return Index < ByteShift
- ? ByteProvider::getConstantZero()
+ ? SDByteProvider::getConstantZero()
: calculateByteProvider(Op->getOperand(0), Index - ByteShift,
Depth + 1, VectorIndex, Index);
}
@@ -8553,7 +8520,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
if (Index >= NarrowByteWidth)
return Op.getOpcode() == ISD::ZERO_EXTEND
- ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
+ ? std::optional<SDByteProvider>(
+ SDByteProvider::getConstantZero())
: std::nullopt;
return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
StartingIndex);
@@ -8603,11 +8571,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
// question
if (Index >= NarrowByteWidth)
return L->getExtensionType() == ISD::ZEXTLOAD
- ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
+ ? std::optional<SDByteProvider>(
+ SDByteProvider::getConstantZero())
: std::nullopt;
unsigned BPVectorIndex = VectorIndex.value_or(0U);
- return ByteProvider::getMemory(L, Index, BPVectorIndex);
+ return SDByteProvider::getSrc(L, Index, BPVectorIndex);
}
}
@@ -8901,23 +8870,24 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
unsigned ByteWidth = VT.getSizeInBits() / 8;
bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
- auto MemoryByteOffset = [&] (ByteProvider P) {
- assert(P.isMemory() && "Must be a memory byte provider");
- unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits();
+ auto MemoryByteOffset = [&](SDByteProvider P) {
+ assert(P.hasSrc() && "Must be a memory byte provider");
+ auto *Load = cast<LoadSDNode>(P.Src.value());
+
+ unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
assert(LoadBitWidth % 8 == 0 &&
"can only analyze providers for individual bytes not bit");
unsigned LoadByteWidth = LoadBitWidth / 8;
- return IsBigEndianTarget
- ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
- : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
+ return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
+ : littleEndianByteAt(LoadByteWidth, P.DestOffset);
};
std::optional<BaseIndexOffset> Base;
SDValue Chain;
SmallPtrSet<LoadSDNode *, 8> Loads;
- std::optional<ByteProvider> FirstByteProvider;
+ std::optional<SDByteProvider> FirstByteProvider;
int64_t FirstOffset = INT64_MAX;
// Check if all the bytes of the OR we are looking at are loaded from the same
@@ -8938,9 +8908,8 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
return SDValue();
continue;
}
- assert(P->isMemory() && "provenance should either be memory or zero");
-
- LoadSDNode *L = P->Load;
+ assert(P->hasSrc() && "provenance should either be memory or zero");
+ auto *L = cast<LoadSDNode>(P->Src.value());
// All loads must share the same chain
SDValue LChain = L->getChain();
@@ -8964,7 +8933,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
if (LoadWidthInBit % 8 != 0)
return SDValue();
- unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8;
+ unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
Ptr.addToOffset(ByteOffsetFromVector);
}
@@ -9021,7 +8990,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
// So the combined value can be loaded from the first load address.
if (MemoryByteOffset(*FirstByteProvider) != 0)
return SDValue();
- LoadSDNode *FirstLoad = FirstByteProvider->Load;
+ auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
// The node we are looking at matches with the pattern, check if we can
// replace it with a single (possibly zero-extended) load and bswap + shift if
More information about the llvm-commits
mailing list