[llvm] [AMDGPULowerBufferFatPointers] Expand const exprs using fat pointers (PR #95558)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 14 08:35:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir
Author: Nikita Popov (nikic)
<details>
<summary>Changes</summary>
Expand all constant expressions that use fat pointers up front, so that the rewriting logic only has to deal with instructions and not the constant expression variants as well.
My primary motivation is to remove the creation of illegal constant expressions (mul and shl) from this pass, but this also cuts down quite a bit on the amount of duplicate logic.
---
Full diff: https://github.com/llvm/llvm-project/pull/95558.diff
4 Files Affected:
- (modified) llvm/include/llvm/IR/ReplaceConstant.h (+5-1)
- (modified) llvm/lib/IR/ReplaceConstant.cpp (+12-5)
- (modified) llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp (+45-157)
- (modified) llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll (+8-5)
``````````diff
diff --git a/llvm/include/llvm/IR/ReplaceConstant.h b/llvm/include/llvm/IR/ReplaceConstant.h
index e1af352f295ae..8c497eb4022c3 100644
--- a/llvm/include/llvm/IR/ReplaceConstant.h
+++ b/llvm/include/llvm/IR/ReplaceConstant.h
@@ -30,9 +30,13 @@ class Function;
/// RemoveDeadConstants by default will remove all dead constants as
/// the final step of the function after replacement, when passed
/// false it will skip this final step.
+///
+/// If \p IncludeSelf is enabled, also convert the passed constants themselves
+/// to instructions, rather than only their users.
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
Function *RestrictToFunc = nullptr,
- bool RemoveDeadConstants = true);
+ bool RemoveDeadConstants = true,
+ bool IncludeSelf = false);
} // end namespace llvm
diff --git a/llvm/lib/IR/ReplaceConstant.cpp b/llvm/lib/IR/ReplaceConstant.cpp
index 67b6fe6fda3b9..c55a80acf0b52 100644
--- a/llvm/lib/IR/ReplaceConstant.cpp
+++ b/llvm/lib/IR/ReplaceConstant.cpp
@@ -51,13 +51,20 @@ static SmallVector<Instruction *, 4> expandUser(BasicBlock::iterator InsertPt,
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
Function *RestrictToFunc,
- bool RemoveDeadConstants) {
+ bool RemoveDeadConstants,
+ bool IncludeSelf) {
// Find all expandable direct users of Consts.
SmallVector<Constant *> Stack;
- for (Constant *C : Consts)
- for (User *U : C->users())
- if (isExpandableUser(U))
- Stack.push_back(cast<Constant>(U));
+ for (Constant *C : Consts) {
+ if (IncludeSelf) {
+ assert(isExpandableUser(C) && "One of the constants is not expandable");
+ Stack.push_back(C);
+ } else {
+ for (User *U : C->users())
+ if (isExpandableUser(U))
+ Stack.push_back(cast<Constant>(U));
+ }
+ }
// Include transitive users.
SetVector<Constant *> ExpandableUsers;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
index dfe0583767313..a8f6ad09fe28c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
@@ -215,6 +215,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
@@ -579,18 +580,14 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
/// buffer fat pointer constant.
static std::pair<Constant *, Constant *>
splitLoweredFatBufferConst(Constant *C) {
- if (auto *AZ = dyn_cast<ConstantAggregateZero>(C))
- return std::make_pair(AZ->getStructElement(0), AZ->getStructElement(1));
- if (auto *SC = dyn_cast<ConstantStruct>(C))
- return std::make_pair(SC->getOperand(0), SC->getOperand(1));
- llvm_unreachable("Conversion should've created a {p8, i32} struct");
+ assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
+ return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
}
namespace {
/// Handle the remapping of ptr addrspace(7) constants.
class FatPtrConstMaterializer final : public ValueMaterializer {
BufferFatPtrToStructTypeMap *TypeMap;
- BufferFatPtrToIntTypeMap *IntTypeMap;
// An internal mapper that is used to recurse into the arguments of constants.
// While the documentation for `ValueMapper` specifies not to use it
// recursively, examination of the logic in mapValue() shows that it can
@@ -600,16 +597,12 @@ class FatPtrConstMaterializer final : public ValueMaterializer {
Constant *materializeBufferFatPtrConst(Constant *C);
- const DataLayout &DL;
-
public:
// UnderlyingMap is the value map this materializer will be filling.
FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
- ValueToValueMapTy &UnderlyingMap,
- BufferFatPtrToIntTypeMap *IntTypeMap,
- const DataLayout &DL)
- : TypeMap(TypeMap), IntTypeMap(IntTypeMap),
- InternalMapper(UnderlyingMap, RF_None, TypeMap, this), DL(DL) {}
+ ValueToValueMapTy &UnderlyingMap)
+ : TypeMap(TypeMap),
+ InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
virtual ~FatPtrConstMaterializer() = default;
Value *materialize(Value *V) override;
@@ -632,10 +625,6 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
UndefValue::get(NewTy->getElementType(1))});
}
- if (isa<GlobalValue>(C))
- report_fatal_error("Global values containing ptr addrspace(7) (buffer "
- "fat pointer) values are not supported");
-
if (auto *VC = dyn_cast<ConstantVector>(C)) {
if (Constant *S = VC->getSplatValue()) {
Constant *NewS = InternalMapper.mapConstant(*S);
@@ -661,120 +650,14 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
}
- // Constant expressions. This code mirrors how we fix up the equivalent
- // instructions later.
- auto *CE = dyn_cast<ConstantExpr>(C);
- if (!CE)
- return nullptr;
- if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
- Constant *RemappedPtr =
- InternalMapper.mapConstant(*cast<Constant>(GEPO->getPointerOperand()));
- auto [Rsrc, Off] = splitLoweredFatBufferConst(RemappedPtr);
- Type *OffTy = Off->getType();
- bool InBounds = GEPO->isInBounds();
-
- MapVector<Value *, APInt> VariableOffs;
- APInt NewConstOffVal = APInt::getZero(BufferOffsetWidth);
- if (!GEPO->collectOffset(DL, BufferOffsetWidth, VariableOffs,
- NewConstOffVal))
- report_fatal_error(
- "Scalable vector or unsized struct in fat pointer GEP");
- Constant *OffAccum = nullptr;
- for (auto [Arg, Multiple] : VariableOffs) {
- Constant *NewArg = InternalMapper.mapConstant(*cast<Constant>(Arg));
- NewArg = ConstantFoldIntegerCast(NewArg, OffTy, /*IsSigned=*/true, DL);
- if (!Multiple.isOne()) {
- if (Multiple.isPowerOf2()) {
- NewArg = ConstantExpr::getShl(
- NewArg, CE->getIntegerValue(OffTy, APInt(BufferOffsetWidth,
- Multiple.logBase2())));
- } else {
- NewArg = ConstantExpr::getMul(NewArg,
- CE->getIntegerValue(OffTy, Multiple));
- }
- }
- if (OffAccum) {
- OffAccum = ConstantExpr::getAdd(OffAccum, NewArg);
- } else {
- OffAccum = NewArg;
- }
- }
- Constant *NewConstOff = CE->getIntegerValue(OffTy, NewConstOffVal);
- if (OffAccum)
- OffAccum = ConstantExpr::getAdd(OffAccum, NewConstOff);
- else
- OffAccum = NewConstOff;
- bool HasNonNegativeOff = false;
- if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
- HasNonNegativeOff = !CI->isNegative();
- }
- Constant *NewOff = ConstantExpr::getAdd(
- Off, OffAccum, /*hasNUW=*/InBounds && HasNonNegativeOff,
- /*hasNSW=*/false);
- return ConstantStruct::get(NewTy, {Rsrc, NewOff});
- }
-
- if (auto *PI = dyn_cast<PtrToIntOperator>(CE)) {
- Constant *Parts =
- InternalMapper.mapConstant(*cast<Constant>(PI->getPointerOperand()));
- auto [Rsrc, Off] = splitLoweredFatBufferConst(Parts);
- // Here, we take advantage of the fact that ptrtoint has a built-in
- // zero-extension behavior.
- unsigned FatPtrWidth =
- DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
- Constant *RsrcInt = CE->getPtrToInt(Rsrc, SrcTy);
- unsigned Width = SrcTy->getScalarSizeInBits();
- Constant *Shift =
- CE->getIntegerValue(SrcTy, APInt(Width, BufferOffsetWidth));
- Constant *OffCast =
- ConstantFoldIntegerCast(Off, SrcTy, /*IsSigned=*/false, DL);
- Constant *RsrcHi = ConstantExpr::getShl(
- RsrcInt, Shift, Width >= FatPtrWidth, Width > FatPtrWidth);
- // This should be an or, but those got recently removed.
- Constant *Result = ConstantExpr::getAdd(RsrcHi, OffCast, true, true);
- return Result;
- }
+ if (isa<GlobalValue>(C))
+ report_fatal_error("Global values containing ptr addrspace(7) (buffer "
+ "fat pointer) values are not supported");
- if (CE->getOpcode() == Instruction::IntToPtr) {
- auto *Arg = cast<Constant>(CE->getOperand(0));
- unsigned FatPtrWidth =
- DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
- unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
- auto *WantedTy = Arg->getType()->getWithNewBitWidth(FatPtrWidth);
- Arg = ConstantFoldIntegerCast(Arg, WantedTy, /*IsSigned=*/false, DL);
-
- Constant *Shift =
- CE->getIntegerValue(WantedTy, APInt(FatPtrWidth, BufferOffsetWidth));
- Type *RsrcIntType = WantedTy->getWithNewBitWidth(RsrcPtrWidth);
- Type *RsrcTy = NewTy->getElementType(0);
- Type *OffTy = WantedTy->getWithNewBitWidth(BufferOffsetWidth);
- Constant *RsrcInt = CE->getTrunc(
- ConstantFoldBinaryOpOperands(Instruction::LShr, Arg, Shift, DL),
- RsrcIntType);
- Constant *Rsrc = CE->getIntToPtr(RsrcInt, RsrcTy);
- Constant *Off = ConstantFoldIntegerCast(Arg, OffTy, /*isSigned=*/false, DL);
-
- return ConstantStruct::get(NewTy, {Rsrc, Off});
- }
+ if (isa<ConstantExpr>(C))
+ report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
+ "fat pointer) values should have been expanded earlier");
- if (auto *AC = dyn_cast<AddrSpaceCastOperator>(CE)) {
- unsigned SrcAS = AC->getSrcAddressSpace();
- unsigned DstAS = AC->getDestAddressSpace();
- auto *Arg = cast<Constant>(AC->getPointerOperand());
- auto *NewArg = InternalMapper.mapConstant(*Arg);
- if (!NewArg)
- return nullptr;
- if (SrcAS == AMDGPUAS::BUFFER_FAT_POINTER &&
- DstAS == AMDGPUAS::BUFFER_FAT_POINTER)
- return NewArg;
- if (SrcAS == AMDGPUAS::BUFFER_RESOURCE &&
- DstAS == AMDGPUAS::BUFFER_FAT_POINTER) {
- auto *NullOff = CE->getNullValue(NewTy->getElementType(1));
- return ConstantStruct::get(NewTy, {NewArg, NullOff});
- }
- report_fatal_error(
- "Unsupported address space cast for a buffer fat pointer");
- }
return nullptr;
}
@@ -782,26 +665,6 @@ Value *FatPtrConstMaterializer::materialize(Value *V) {
Constant *C = dyn_cast<Constant>(V);
if (!C)
return nullptr;
- if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
- // As a special case, adjust GEP constants that have a ptr addrspace(7) in
- // their source types here, since the earlier local changes didn't handle
- // htis.
- Type *SrcTy = GEPO->getSourceElementType();
- Type *NewSrcTy = IntTypeMap->remapType(SrcTy);
- if (SrcTy != NewSrcTy) {
- SmallVector<Constant *> Ops;
- Ops.reserve(GEPO->getNumOperands());
- for (const Use &U : GEPO->operands())
- Ops.push_back(cast<Constant>(U.get()));
- auto *NewGEP = ConstantExpr::getGetElementPtr(
- NewSrcTy, Ops[0], ArrayRef<Constant *>(Ops).slice(1),
- GEPO->getNoWrapFlags(), GEPO->getInRange());
- LLVM_DEBUG(dbgs() << "p7-getting GEP: " << *GEPO << " becomes " << *NewGEP
- << "\n");
- Value *FurtherMap = materialize(NewGEP);
- return FurtherMap ? FurtherMap : NewGEP;
- }
- }
// Structs and other types that happen to contain fat pointers get remapped
// by the mapValue() logic.
if (!isBufferFatPtrConst(C))
@@ -1782,14 +1645,9 @@ class AMDGPULowerBufferFatPointers : public ModulePass {
static bool containsBufferFatPointers(const Function &F,
BufferFatPtrToStructTypeMap *TypeMap) {
bool HasFatPointers = false;
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
+ for (const BasicBlock &BB : F)
+ for (const Instruction &I : BB)
HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
- for (const Use &U : I.operands())
- if (auto *C = dyn_cast<Constant>(U.get()))
- HasFatPointers |= isBufferFatPtrConst(C);
- }
- }
return HasFatPointers;
}
@@ -1888,6 +1746,36 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
"buffer resource pointers (address space 8) instead.");
}
+ {
+ // Collect all constant exprs and aggregates referenced by any function.
+ SmallVector<Constant *, 8> Worklist;
+ for (Function &F : M.functions())
+ for (Instruction &I : instructions(F))
+ for (Value *Op : I.operands())
+ if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
+ Worklist.push_back(cast<Constant>(Op));
+
+ // Recursively look for any referenced buffer pointer constants.
+ SmallPtrSet<Constant *, 8> Visited;
+ SetVector<Constant *> BufferFatPtrConsts;
+ while (!Worklist.empty()) {
+ Constant *C = Worklist.pop_back_val();
+ if (!Visited.insert(C).second)
+ continue;
+ if (isBufferFatPtrOrVector(C->getType()))
+ BufferFatPtrConsts.insert(C);
+ for (Value *Op : C->operands())
+ if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
+ Worklist.push_back(cast<Constant>(Op));
+ }
+
+ // Expand all constant expressions using fat buffer pointers to
+ // instructions.
+ Changed |= convertUsersOfConstantsToInstructions(
+ BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
+ }
+
StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
for (Function &F : M.functions()) {
bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
@@ -1903,7 +1791,7 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
SmallVector<Function *> Intrinsics;
// Keep one big map so as to memoize constants across functions.
ValueToValueMapTy CloneMap;
- FatPtrConstMaterializer Materializer(&StructTM, CloneMap, &IntTM, DL);
+ FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
for (auto [F, InterfaceChange] : NeedsRemap) {
diff --git a/llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll b/llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll
index 77ca227bcf663..e4424f317f235 100644
--- a/llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll
+++ b/llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll
@@ -143,7 +143,8 @@ define ptr addrspace(7) @gep_p7_from_p7() {
define i160 @ptrtoint() {
; CHECK-LABEL: define i160 @ptrtoint
; CHECK-SAME: () #[[ATTR0]] {
-; CHECK-NEXT: ret i160 add nuw nsw (i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), i160 12)
+; CHECK-NEXT: [[TMP1:%.*]] = or i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), 12
+; CHECK-NEXT: ret i160 [[TMP1]]
;
ret i160 ptrtoint(
ptr addrspace(7) getelementptr(
@@ -154,7 +155,8 @@ define i160 @ptrtoint() {
define i256 @ptrtoint_long() {
; CHECK-LABEL: define i256 @ptrtoint_long
; CHECK-SAME: () #[[ATTR0]] {
-; CHECK-NEXT: ret i256 add nuw nsw (i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), i256 12)
+; CHECK-NEXT: [[TMP1:%.*]] = or i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), 12
+; CHECK-NEXT: ret i256 [[TMP1]]
;
ret i256 ptrtoint(
ptr addrspace(7) getelementptr(
@@ -165,7 +167,8 @@ define i256 @ptrtoint_long() {
define i64 @ptrtoint_short() {
; CHECK-LABEL: define i64 @ptrtoint_short
; CHECK-SAME: () #[[ATTR0]] {
-; CHECK-NEXT: ret i64 add nuw nsw (i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), i64 12)
+; CHECK-NEXT: [[TMP1:%.*]] = or i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), 12
+; CHECK-NEXT: ret i64 [[TMP1]]
;
ret i64 ptrtoint(
ptr addrspace(7) getelementptr(
@@ -176,7 +179,7 @@ define i64 @ptrtoint_short() {
define i32 @ptrtoint_very_short() {
; CHECK-LABEL: define i32 @ptrtoint_very_short
; CHECK-SAME: () #[[ATTR0]] {
-; CHECK-NEXT: ret i32 add nuw nsw (i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32), i32 12)
+; CHECK-NEXT: ret i32 12
;
ret i32 ptrtoint(
ptr addrspace(7) getelementptr(
@@ -212,7 +215,7 @@ define <2 x ptr addrspace(7)> @inttoptr_vec() {
define i32 @fancy_zero() {
; CHECK-LABEL: define i32 @fancy_zero
; CHECK-SAME: () #[[ATTR0]] {
-; CHECK-NEXT: ret i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32)
+; CHECK-NEXT: ret i32 0
;
ret i32 ptrtoint (
ptr addrspace(7) addrspacecast (ptr addrspace(8) @buf to ptr addrspace(7))
``````````
</details>
https://github.com/llvm/llvm-project/pull/95558
More information about the llvm-commits
mailing list