[llvm] 72fc08a - [InstCombine] Teach alloca replacement to handle `addrspacecast`
Michael Liao via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 11 08:55:31 PDT 2023
Author: Michael Liao
Date: 2023-04-11T11:47:37-04:00
New Revision: 72fc08a5412ec7ee7f0b904926db16cd86c1f876
URL: https://github.com/llvm/llvm-project/commit/72fc08a5412ec7ee7f0b904926db16cd86c1f876
DIFF: https://github.com/llvm/llvm-project/commit/72fc08a5412ec7ee7f0b904926db16cd86c1f876.diff
LOG: [InstCombine] Teach alloca replacement to handle `addrspacecast`
- As the address space cast may not be valid on a specific target,
`addrspacecast` is not handled when an `alloca` is able to be replaced
with the source of memcpy/memmove. This patch addresses that by
querying a target hook on whether that address space cast is valid.
For example, on most GPU targets, the cast from a global pointer to a
generic pointer is valid.
- If that cast is allowed (by querying `isValidAddrSpaceCast`), the
replacement is enhanced to handle that `addrspacecast` as well.
Reviewed By: yaxunl
Differential Revision: https://reviews.llvm.org/D147025
Added:
Modified:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7ffeffee18162..9f074e537322e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -421,6 +421,10 @@ class TargetTransformInfo {
// even taking non-uniform arguments
bool isAlwaysUniform(const Value *V) const;
+ /// Query the target whether the specified address space cast from FromAS to
+ /// ToAS is valid.
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
+
/// Returns the address space ID for a target's 'flat' address space. Note
/// this is not necessarily the same as addrspace(0), which LLVM sometimes
/// refers to as the generic address space. The flat address space is a
@@ -1685,6 +1689,7 @@ class TargetTransformInfo::Concept {
virtual bool useGPUDivergenceAnalysis() = 0;
virtual bool isSourceOfDivergence(const Value *V) = 0;
virtual bool isAlwaysUniform(const Value *V) = 0;
+ virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const = 0;
virtual unsigned getFlatAddressSpace() = 0;
virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const = 0;
@@ -2062,6 +2067,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isAlwaysUniform(V);
}
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const override {
+ return Impl.isValidAddrSpaceCast(FromAS, ToAS);
+ }
+
unsigned getFlatAddressSpace() override { return Impl.getFlatAddressSpace(); }
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index fbd5202097014..721959fa77156 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -95,6 +95,10 @@ class TargetTransformInfoImplBase {
bool isAlwaysUniform(const Value *V) const { return false; }
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+ return false;
+ }
+
unsigned getFlatAddressSpace() const { return -1; }
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c6c741c831e68..88d70613caba0 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
bool isAlwaysUniform(const Value *V) { return false; }
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+ return false;
+ }
+
unsigned getFlatAddressSpace() {
// Return an invalid address space.
return -1;
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index a876385581e72..20be9dbd35bcf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -529,6 +529,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
unsigned Depth = 0,
bool AllowMultipleUsers = false) = 0;
+
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
};
} // namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a29ccf0653b65..68674b9793bb2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -274,6 +274,11 @@ bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const {
return TTIImpl->isAlwaysUniform(V);
}
+bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS,
+ unsigned ToAS) const {
+ return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS);
+}
+
unsigned TargetTransformInfo::getFlatAddressSpace() const {
return TTIImpl->getFlatAddressSpace();
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index a4041fcc3ccee..f965b54c2a880 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -168,6 +168,28 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
bool isSourceOfDivergence(const Value *V) const;
bool isAlwaysUniform(const Value *V) const;
+ bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+ if (ToAS == AMDGPUAS::FLAT_ADDRESS) {
+ switch (FromAS) {
+ case AMDGPUAS::GLOBAL_ADDRESS:
+ case AMDGPUAS::CONSTANT_ADDRESS:
+ case AMDGPUAS::CONSTANT_ADDRESS_32BIT:
+ case AMDGPUAS::LOCAL_ADDRESS:
+ case AMDGPUAS::PRIVATE_ADDRESS:
+ return true;
+ default:
+ break;
+ }
+ return false;
+ }
+ if ((FromAS == AMDGPUAS::CONSTANT_ADDRESS_32BIT &&
+ ToAS == AMDGPUAS::CONSTANT_ADDRESS) ||
+ (FromAS == AMDGPUAS::CONSTANT_ADDRESS &&
+ ToAS == AMDGPUAS::CONSTANT_ADDRESS_32BIT))
+ return true;
+ return false;
+ }
+
unsigned getFlatAddressSpace() const {
// Don't bother running InferAddressSpaces pass on graphics shaders which
// don't use flat addressing.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index e9bb4089c4e9b..a069a0b498110 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -260,8 +260,8 @@ namespace {
// instruction.
class PointerReplacer {
public:
- PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
- : IC(IC), Root(Root) {}
+ PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS)
+ : IC(IC), Root(Root), FromAS(SrcAS) {}
bool collectUsers();
void replacePointer(Value *V);
@@ -274,11 +274,21 @@ class PointerReplacer {
return I == &Root || Worklist.contains(I);
}
+ bool isEqualOrValidAddrSpaceCast(const Instruction *I,
+ unsigned FromAS) const {
+ const auto *ASC = dyn_cast<AddrSpaceCastInst>(I);
+ if (!ASC)
+ return false;
+ unsigned ToAS = ASC->getDestAddressSpace();
+ return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
+ }
+
SmallPtrSet<Instruction *, 32> ValuesToRevisit;
SmallSetVector<Instruction *, 4> Worklist;
MapVector<Value *, Value *> WorkMap;
InstCombinerImpl &IC;
Instruction &Root;
+ unsigned FromAS;
};
} // end anonymous namespace
@@ -342,6 +352,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) {
if (MI->isVolatile())
return false;
Worklist.insert(Inst);
+ } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) {
+ Worklist.insert(Inst);
} else if (Inst->isLifetimeStartOrEnd()) {
continue;
} else {
@@ -427,6 +439,22 @@ void PointerReplacer::replace(Instruction *I) {
IC.eraseInstFromFunction(*MemCpy);
WorkMap[MemCpy] = NewI;
+ } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+ auto *V = getReplacement(ASC->getPointerOperand());
+ assert(V && "Operand not replaced");
+ assert(isEqualOrValidAddrSpaceCast(
+ ASC, V->getType()->getPointerAddressSpace()) &&
+ "Invalid address space cast!");
+ auto *NewV = V;
+ if (V->getType()->getPointerAddressSpace() !=
+ ASC->getType()->getPointerAddressSpace()) {
+ auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
+ NewI->takeName(ASC);
+ IC.InsertNewInstWith(NewI, *ASC);
+ NewV = NewI;
+ }
+ IC.replaceInstUsesWith(*ASC, NewV);
+ IC.eraseInstFromFunction(*ASC);
} else {
llvm_unreachable("should never reach here");
}
@@ -519,7 +547,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
return NewI;
}
- PointerReplacer PtrReplacer(*this, AI);
+ PointerReplacer PtrReplacer(*this, AI, SrcAddrSpace);
if (PtrReplacer.collectUsers()) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 236ebf4bebc8c..a1bdc54c009ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -196,6 +196,10 @@ std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
return std::nullopt;
}
+bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+ return TTI.isValidAddrSpaceCast(FromAS, ToAS);
+}
+
Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
return llvm::emitGEPOffset(&Builder, DL, GEP);
}
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
index 44f343fff5485..2d2a93b28a5ba 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
@@ -137,11 +137,9 @@ define amdgpu_kernel void @memcpy_constant_intrinsic_ptr_to_alloca(ptr addrspace
; Alloca is written through a flat pointer
define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(
-; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT: [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT: call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(31) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(31) [[ARG:%.*]], i64 31, i1 false)
-; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(5) [[ALLOCA]], i32 0, i32 [[IDX:%.*]]
-; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(5) [[GEP]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr [32 x i8], ptr addrspace(4) [[ARG:%.*]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(4) [[GEP]], align 1
; CHECK-NEXT: ret i8 [[LOAD]]
;
%alloca = alloca [32 x i8], align 4, addrspace(5)
@@ -155,9 +153,7 @@ define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(ptr addrspace
; Alloca is only addressed through flat pointer.
define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(
-; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT: [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT: call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(32) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(32) [[ARG:%.*]], i64 32, i1 false)
+; CHECK-NEXT: [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr [[ALLOCA_CAST_ASC]], i64 0, i64 [[TMP1]]
; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[GEP]], align 1
@@ -202,9 +198,7 @@ bb:
define amdgpu_kernel void @byref_infloop_addrspacecast(ptr %scratch, ptr addrspace(4) byref(%struct.ty) align 4 %arg) local_unnamed_addr #1 {
; CHECK-LABEL: @byref_infloop_addrspacecast(
; CHECK-NEXT: bb:
-; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [4 x i32], align 4, addrspace(5)
-; CHECK-NEXT: [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT: call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], ptr addrspace(4) noundef align 4 dereferenceable(16) [[ARG:%.*]], i64 16, i1 false)
+; CHECK-NEXT: [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 4 dereferenceable(16) [[SCRATCH:%.*]], ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], i64 16, i1 false)
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
index 8f25c71e47823..6a0e7098ab744 100644
--- a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
@@ -427,6 +427,40 @@ entry:
ret i8 %load
}
+declare i8 @readonly_callee(ptr readonly nocapture)
+
+define i8 @call_readonly_remove_alloca() {
+; CHECK-LABEL: @call_readonly_remove_alloca(
+; CHECK-NEXT: [[V:%.*]] = call i8 @readonly_callee(ptr nonnull @g1)
+; CHECK-NEXT: ret i8 [[V]]
+;
+ %alloca = alloca [32 x i8], addrspace(1)
+ call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 32, i1 false)
+ %p = addrspacecast ptr addrspace(1) %alloca to ptr
+ %v = call i8 @readonly_callee(ptr %p)
+ ret i8 %v
+}
+
+define i8 @call_readonly_keep_alloca2() {
+; CHECK-LABEL: @call_readonly_keep_alloca2(
+; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 1, addrspace(1)
+; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(16) @g1, i64 16, i1 false)
+; CHECK-NEXT: [[A1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 16
+; CHECK-NEXT: call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[A1]], ptr addrspace(1) noundef align 16 dereferenceable(16) @g2, i64 16, i1 false)
+; CHECK-NEXT: [[P:%.*]] = addrspacecast ptr addrspace(1) [[ALLOCA]] to ptr
+; CHECK-NEXT: [[V:%.*]] = call i8 @readonly_callee(ptr [[P]])
+; CHECK-NEXT: ret i8 [[V]]
+;
+ %alloca = alloca [32 x i8], addrspace(1)
+ call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 16, i1 false)
+ %a1 = getelementptr inbounds [32 x i8], ptr addrspace(1) %alloca, i32 0, i32 16
+ call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) %a1, ptr addrspace(1) @g2, i64 16, i1 false)
+ %p = addrspacecast ptr addrspace(1) %alloca to ptr
+ %v = call i8 @readonly_callee(ptr %p)
+ ret i8 %v
+}
+
declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1), ptr, i64, i1)
declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
declare void @llvm.memcpy.p0.p1.i64(ptr, ptr addrspace(1), i64, i1)
+declare void @llvm.memcpy.p1.p1.i64(ptr addrspace(1), ptr addrspace(1), i64, i1)
More information about the llvm-commits
mailing list