[llvm] 72fc08a - [InstCombine] Teach alloca replacement to handle `addrspacecast`

Michael Liao via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 11 08:55:31 PDT 2023


Author: Michael Liao
Date: 2023-04-11T11:47:37-04:00
New Revision: 72fc08a5412ec7ee7f0b904926db16cd86c1f876

URL: https://github.com/llvm/llvm-project/commit/72fc08a5412ec7ee7f0b904926db16cd86c1f876
DIFF: https://github.com/llvm/llvm-project/commit/72fc08a5412ec7ee7f0b904926db16cd86c1f876.diff

LOG: [InstCombine] Teach alloca replacement to handle `addrspacecast`

- As the address space cast may not be valid on a specific target,
  `addrspacecast` is not handled when an `alloca` is able to be replaced
  with the source of memcpy/memmove. This patch addresses that by
  querying a target hook on whether that address space cast is valid.
  For example, on most GPU targets, the cast from a global pointer to a
  generic pointer is valid.
- If that cast is allowed (by querying `isValidAddrSpaceCast`), the
  replacement is enhanced to handle that `addrspacecast` as well.

Reviewed By: yaxunl

Differential Revision: https://reviews.llvm.org/D147025

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
    llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
    llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7ffeffee18162..9f074e537322e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -421,6 +421,10 @@ class TargetTransformInfo {
   // even taking non-uniform arguments
   bool isAlwaysUniform(const Value *V) const;
 
+  /// Query the target whether the specified address space cast from FromAS to
+  /// ToAS is valid.
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
+
   /// Returns the address space ID for a target's 'flat' address space. Note
   /// this is not necessarily the same as addrspace(0), which LLVM sometimes
   /// refers to as the generic address space. The flat address space is a
@@ -1685,6 +1689,7 @@ class TargetTransformInfo::Concept {
   virtual bool useGPUDivergenceAnalysis() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isAlwaysUniform(const Value *V) = 0;
+  virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const = 0;
   virtual unsigned getFlatAddressSpace() = 0;
   virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                           Intrinsic::ID IID) const = 0;
@@ -2062,6 +2067,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isAlwaysUniform(V);
   }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const override {
+    return Impl.isValidAddrSpaceCast(FromAS, ToAS);
+  }
+
   unsigned getFlatAddressSpace() override { return Impl.getFlatAddressSpace(); }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index fbd5202097014..721959fa77156 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -95,6 +95,10 @@ class TargetTransformInfoImplBase {
 
   bool isAlwaysUniform(const Value *V) const { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return false;
+  }
+
   unsigned getFlatAddressSpace() const { return -1; }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c6c741c831e68..88d70613caba0 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   bool isAlwaysUniform(const Value *V) { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return false;
+  }
+
   unsigned getFlatAddressSpace() {
     // Return an invalid address space.
     return -1;

diff  --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index a876385581e72..20be9dbd35bcf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -529,6 +529,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                              unsigned Depth = 0,
                              bool AllowMultipleUsers = false) = 0;
+
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 };
 
 } // namespace llvm

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a29ccf0653b65..68674b9793bb2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -274,6 +274,11 @@ bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const {
   return TTIImpl->isAlwaysUniform(V);
 }
 
+bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS,
+                                                     unsigned ToAS) const {
+  return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 unsigned TargetTransformInfo::getFlatAddressSpace() const {
   return TTIImpl->getFlatAddressSpace();
 }

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index a4041fcc3ccee..f965b54c2a880 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -168,6 +168,28 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
   bool isSourceOfDivergence(const Value *V) const;
   bool isAlwaysUniform(const Value *V) const;
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    if (ToAS == AMDGPUAS::FLAT_ADDRESS) {
+      switch (FromAS) {
+      case AMDGPUAS::GLOBAL_ADDRESS:
+      case AMDGPUAS::CONSTANT_ADDRESS:
+      case AMDGPUAS::CONSTANT_ADDRESS_32BIT:
+      case AMDGPUAS::LOCAL_ADDRESS:
+      case AMDGPUAS::PRIVATE_ADDRESS:
+        return true;
+      default:
+        break;
+      }
+      return false;
+    }
+    if ((FromAS == AMDGPUAS::CONSTANT_ADDRESS_32BIT &&
+         ToAS == AMDGPUAS::CONSTANT_ADDRESS) ||
+        (FromAS == AMDGPUAS::CONSTANT_ADDRESS &&
+         ToAS == AMDGPUAS::CONSTANT_ADDRESS_32BIT))
+      return true;
+    return false;
+  }
+
   unsigned getFlatAddressSpace() const {
     // Don't bother running InferAddressSpaces pass on graphics shaders which
     // don't use flat addressing.

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index e9bb4089c4e9b..a069a0b498110 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -260,8 +260,8 @@ namespace {
 // instruction.
 class PointerReplacer {
 public:
-  PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
-    : IC(IC), Root(Root) {}
+  PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS)
+      : IC(IC), Root(Root), FromAS(SrcAS) {}
 
   bool collectUsers();
   void replacePointer(Value *V);
@@ -274,11 +274,21 @@ class PointerReplacer {
     return I == &Root || Worklist.contains(I);
   }
 
+  bool isEqualOrValidAddrSpaceCast(const Instruction *I,
+                                   unsigned FromAS) const {
+    const auto *ASC = dyn_cast<AddrSpaceCastInst>(I);
+    if (!ASC)
+      return false;
+    unsigned ToAS = ASC->getDestAddressSpace();
+    return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
+  }
+
   SmallPtrSet<Instruction *, 32> ValuesToRevisit;
   SmallSetVector<Instruction *, 4> Worklist;
   MapVector<Value *, Value *> WorkMap;
   InstCombinerImpl &IC;
   Instruction &Root;
+  unsigned FromAS;
 };
 } // end anonymous namespace
 
@@ -342,6 +352,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) {
       if (MI->isVolatile())
         return false;
       Worklist.insert(Inst);
+    } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) {
+      Worklist.insert(Inst);
     } else if (Inst->isLifetimeStartOrEnd()) {
       continue;
     } else {
@@ -427,6 +439,22 @@ void PointerReplacer::replace(Instruction *I) {
 
     IC.eraseInstFromFunction(*MemCpy);
     WorkMap[MemCpy] = NewI;
+  } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+    auto *V = getReplacement(ASC->getPointerOperand());
+    assert(V && "Operand not replaced");
+    assert(isEqualOrValidAddrSpaceCast(
+               ASC, V->getType()->getPointerAddressSpace()) &&
+           "Invalid address space cast!");
+    auto *NewV = V;
+    if (V->getType()->getPointerAddressSpace() !=
+        ASC->getType()->getPointerAddressSpace()) {
+      auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
+      NewI->takeName(ASC);
+      IC.InsertNewInstWith(NewI, *ASC);
+      NewV = NewI;
+    }
+    IC.replaceInstUsesWith(*ASC, NewV);
+    IC.eraseInstFromFunction(*ASC);
   } else {
     llvm_unreachable("should never reach here");
   }
@@ -519,7 +547,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
         return NewI;
       }
 
-      PointerReplacer PtrReplacer(*this, AI);
+      PointerReplacer PtrReplacer(*this, AI, SrcAddrSpace);
       if (PtrReplacer.collectUsers()) {
         for (Instruction *Delete : ToDelete)
           eraseInstFromFunction(*Delete);

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 236ebf4bebc8c..a1bdc54c009ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -196,6 +196,10 @@ std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
   return std::nullopt;
 }
 
+bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+  return TTI.isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
   return llvm::emitGEPOffset(&Builder, DL, GEP);
 }

diff  --git a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
index 44f343fff5485..2d2a93b28a5ba 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
@@ -137,11 +137,9 @@ define amdgpu_kernel void @memcpy_constant_intrinsic_ptr_to_alloca(ptr addrspace
 ; Alloca is written through a flat pointer
 define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
 ; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(31) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(31) [[ARG:%.*]], i64 31, i1 false)
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(5) [[ALLOCA]], i32 0, i32 [[IDX:%.*]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(5) [[GEP]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr [32 x i8], ptr addrspace(4) [[ARG:%.*]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(4) [[GEP]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %alloca = alloca [32 x i8], align 4, addrspace(5)
@@ -155,9 +153,7 @@ define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(ptr addrspace
 ; Alloca is only addressed through flat pointer.
 define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
 ; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(32) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(32) [[ARG:%.*]], i64 32, i1 false)
+; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
 ; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr [[ALLOCA_CAST_ASC]], i64 0, i64 [[TMP1]]
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[GEP]], align 1
@@ -202,9 +198,7 @@ bb:
 define amdgpu_kernel void @byref_infloop_addrspacecast(ptr %scratch, ptr addrspace(4) byref(%struct.ty) align 4 %arg) local_unnamed_addr #1 {
 ; CHECK-LABEL: @byref_infloop_addrspacecast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [4 x i32], align 4, addrspace(5)
-; CHECK-NEXT:    [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], ptr addrspace(4) noundef align 4 dereferenceable(16) [[ARG:%.*]], i64 16, i1 false)
+; CHECK-NEXT:    [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
 ; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 4 dereferenceable(16) [[SCRATCH:%.*]], ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], i64 16, i1 false)
 ; CHECK-NEXT:    ret void
 ;

diff  --git a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
index 8f25c71e47823..6a0e7098ab744 100644
--- a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
@@ -427,6 +427,40 @@ entry:
   ret i8 %load
 }
 
+declare i8 @readonly_callee(ptr readonly nocapture)
+
+define i8 @call_readonly_remove_alloca() {
+; CHECK-LABEL: @call_readonly_remove_alloca(
+; CHECK-NEXT:    [[V:%.*]] = call i8 @readonly_callee(ptr nonnull @g1)
+; CHECK-NEXT:    ret i8 [[V]]
+;
+  %alloca = alloca [32 x i8], addrspace(1)
+  call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 32, i1 false)
+  %p = addrspacecast ptr addrspace(1) %alloca to ptr
+  %v = call i8 @readonly_callee(ptr %p)
+  ret i8 %v
+}
+
+define i8 @call_readonly_keep_alloca2() {
+; CHECK-LABEL: @call_readonly_keep_alloca2(
+; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 1, addrspace(1)
+; CHECK-NEXT:    call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(16) @g1, i64 16, i1 false)
+; CHECK-NEXT:    [[A1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 16
+; CHECK-NEXT:    call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[A1]], ptr addrspace(1) noundef align 16 dereferenceable(16) @g2, i64 16, i1 false)
+; CHECK-NEXT:    [[P:%.*]] = addrspacecast ptr addrspace(1) [[ALLOCA]] to ptr
+; CHECK-NEXT:    [[V:%.*]] = call i8 @readonly_callee(ptr [[P]])
+; CHECK-NEXT:    ret i8 [[V]]
+;
+  %alloca = alloca [32 x i8], addrspace(1)
+  call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 16, i1 false)
+  %a1 = getelementptr inbounds [32 x i8], ptr addrspace(1) %alloca, i32 0, i32 16
+  call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) %a1, ptr addrspace(1) @g2, i64 16, i1 false)
+  %p = addrspacecast ptr addrspace(1) %alloca to ptr
+  %v = call i8 @readonly_callee(ptr %p)
+  ret i8 %v
+}
+
 declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1), ptr, i64, i1)
 declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
 declare void @llvm.memcpy.p0.p1.i64(ptr, ptr addrspace(1), i64, i1)
+declare void @llvm.memcpy.p1.p1.i64(ptr addrspace(1), ptr addrspace(1), i64, i1)


        


More information about the llvm-commits mailing list