[llvm] InferAddressSpaces: Factor replacement loop into function (PR #104430)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 04:17:44 PDT 2024


https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/104430

None

>From 091890fa3d255c05574e5c5eec84c92cde205264 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Thu, 15 Aug 2024 14:51:07 +0400
Subject: [PATCH] InferAddressSpaces: Factor replacement loop into function

---
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 213 ++++++++++--------
 1 file changed, 113 insertions(+), 100 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 0c8aee8a494c03..e223d4b606d339 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -184,6 +184,7 @@ class InferAddressSpaces : public FunctionPass {
 
 class InferAddressSpacesImpl {
   AssumptionCache &AC;
+  Function *F = nullptr;
   const DominatorTree *DT = nullptr;
   const TargetTransformInfo *TTI = nullptr;
   const DataLayout *DL = nullptr;
@@ -212,14 +213,17 @@ class InferAddressSpacesImpl {
       const PredicatedAddrSpaceMapTy &PredicatedAS,
       SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
 
+  void performPointerReplacement(
+      Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
+      SmallVectorImpl<Instruction *> &DeadInstructions) const;
+
   // Changes the flat address expressions in function F to point to specific
   // address spaces if InferredAddrSpace says so. Postorder is the postorder of
   // all flat expressions in the use-def graph of function F.
-  bool
-  rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
-                              const ValueToAddrSpaceMapTy &InferredAddrSpace,
-                              const PredicatedAddrSpaceMapTy &PredicatedAS,
-                              Function *F) const;
+  bool rewriteWithNewAddressSpaces(
+      ArrayRef<WeakTrackingVH> Postorder,
+      const ValueToAddrSpaceMapTy &InferredAddrSpace,
+      const PredicatedAddrSpaceMapTy &PredicatedAS) const;
 
   void appendsFlatAddressExpressionToPostorderStack(
       Value *V, PostorderStackTy &PostorderStack,
@@ -842,8 +846,9 @@ unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1,
   return (AS1 == AS2) ? AS1 : FlatAddrSpace;
 }
 
-bool InferAddressSpacesImpl::run(Function &F) {
-  DL = &F.getDataLayout();
+bool InferAddressSpacesImpl::run(Function &F_) {
+  F = &F_;
+  DL = &F->getDataLayout();
 
   if (AssumeDefaultIsFlatAddressSpace)
     FlatAddrSpace = 0;
@@ -855,7 +860,7 @@ bool InferAddressSpacesImpl::run(Function &F) {
   }
 
   // Collects all flat address expressions in postorder.
-  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
+  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(*F);
 
   // Runs a data-flow analysis to refine the address spaces of every expression
   // in Postorder.
@@ -865,8 +870,8 @@ bool InferAddressSpacesImpl::run(Function &F) {
 
   // Changes the address spaces of the flat address expressions who are inferred
   // to point to a specific address space.
-  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS,
-                                     &F);
+  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace,
+                                     PredicatedAS);
 }
 
 // Constants need to be tracked through RAUW to handle cases with nested
@@ -1164,10 +1169,105 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I,
   return I;
 }
 
+void InferAddressSpacesImpl::performPointerReplacement(
+    Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
+    SmallVectorImpl<Instruction *> &DeadInstructions) const {
+  // Rewrite use U of V in terms of NewV; the strategy depends on U's user.
+  User *CurUser = U.getUser();
+
+  unsigned AddrSpace = V->getType()->getPointerAddressSpace();
+  if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
+    return;
+
+  // Skip if the current user is the new value itself.
+  if (CurUser == NewV)
+    return;
+  // Leave alone instruction users that belong to a function other than F.
+  if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
+      CurUserI && CurUserI->getFunction() != F)
+    return;
+
+  // Handle more complex cases, like intrinsics that need to be remangled.
+  if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
+    if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
+      return;
+  }
+
+  if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+    if (rewriteIntrinsicOperands(II, V, NewV))
+      return;
+  }
+
+  if (isa<Instruction>(CurUser)) {
+    if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
+      // If we can infer that both pointers are in the same addrspace,
+      // transform e.g.
+      //   %cmp = icmp eq float* %p, %q
+      // into
+      //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
+
+      unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+      int SrcIdx = U.getOperandNo();
+      int OtherIdx = (SrcIdx == 0) ? 1 : 0;
+      Value *OtherSrc = Cmp->getOperand(OtherIdx);
+
+      if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
+        if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+          Cmp->setOperand(OtherIdx, OtherNewV);
+          Cmp->setOperand(SrcIdx, NewV);
+          return;
+        }
+      }
+
+      // Even if the types mismatch, a constant operand can be cast when safe.
+      if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
+        if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
+          Cmp->setOperand(SrcIdx, NewV);
+          Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
+                                        KOtherSrc, NewV->getType()));
+          return;
+        }
+      }
+    }
+    // A cast user already producing NewV's address space is replaced by NewV.
+    if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
+      unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+      if (ASC->getDestAddressSpace() == NewAS) {
+        ASC->replaceAllUsesWith(NewV);
+        DeadInstructions.push_back(ASC);
+        return;
+      }
+    }
+
+    // Otherwise, replace the use with NewV cast back to V's original type.
+    if (Instruction *VInst = dyn_cast<Instruction>(V)) {
+      // Don't create a copy of the original addrspacecast.
+      if (U == V && isa<AddrSpaceCastInst>(V))
+        return;
+
+      // Insert the addrspacecast after NewV (or after V if NewV isn't an inst).
+      BasicBlock::iterator InsertPos;
+      if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
+        InsertPos = std::next(NewVInst->getIterator());
+      else
+        InsertPos = std::next(VInst->getIterator());
+      // Step past any PHI nodes at the insertion point.
+      while (isa<PHINode>(InsertPos))
+        ++InsertPos;
+      // This instruction may contain multiple uses of V, update them all.
+      CurUser->replaceUsesOfWith(
+          V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
+    } else {
+      CurUser->replaceUsesOfWith(V, ConstantExpr::getAddrSpaceCast(
+                                        cast<Constant>(NewV), V->getType()));
+    }
+  }
+}
+
 bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     ArrayRef<WeakTrackingVH> Postorder,
     const ValueToAddrSpaceMapTy &InferredAddrSpace,
-    const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const {
+    const PredicatedAddrSpaceMapTy &PredicatedAS) const {
   // For each address expression to be modified, creates a clone of it with its
   // pointer operands converted to the new address space. Since the pointer
   // operands are converted, the clone is naturally in the new address space by
@@ -1258,100 +1358,13 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     Value::use_iterator I, E, Next;
     for (I = V->use_begin(), E = V->use_end(); I != E;) {
       Use &U = *I;
-      User *CurUser = U.getUser();
 
       // Some users may see the same pointer operand in multiple operands. Skip
       // to the next instruction.
       I = skipToNextUser(I, E);
 
-      unsigned AddrSpace = V->getType()->getPointerAddressSpace();
-      if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
-        continue;
-
-      // Skip if the current user is the new value itself.
-      if (CurUser == NewV)
-        continue;
-
-      if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
-          CurUserI && CurUserI->getFunction() != F)
-        continue;
-
-      // Handle more complex cases like intrinsic that need to be remangled.
-      if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
-        if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
-          continue;
-      }
-
-      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
-        if (rewriteIntrinsicOperands(II, V, NewV))
-          continue;
-      }
-
-      if (isa<Instruction>(CurUser)) {
-        if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
-          // If we can infer that both pointers are in the same addrspace,
-          // transform e.g.
-          //   %cmp = icmp eq float* %p, %q
-          // into
-          //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
-
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-          int SrcIdx = U.getOperandNo();
-          int OtherIdx = (SrcIdx == 0) ? 1 : 0;
-          Value *OtherSrc = Cmp->getOperand(OtherIdx);
-
-          if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
-            if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
-              Cmp->setOperand(OtherIdx, OtherNewV);
-              Cmp->setOperand(SrcIdx, NewV);
-              continue;
-            }
-          }
-
-          // Even if the type mismatches, we can cast the constant.
-          if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
-            if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
-              Cmp->setOperand(SrcIdx, NewV);
-              Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
-                                            KOtherSrc, NewV->getType()));
-              continue;
-            }
-          }
-        }
-
-        if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-          if (ASC->getDestAddressSpace() == NewAS) {
-            ASC->replaceAllUsesWith(NewV);
-            DeadInstructions.push_back(ASC);
-            continue;
-          }
-        }
-
-        // Otherwise, replaces the use with flat(NewV).
-        if (Instruction *VInst = dyn_cast<Instruction>(V)) {
-          // Don't create a copy of the original addrspacecast.
-          if (U == V && isa<AddrSpaceCastInst>(V))
-            continue;
-
-          // Insert the addrspacecast after NewV.
-          BasicBlock::iterator InsertPos;
-          if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
-            InsertPos = std::next(NewVInst->getIterator());
-          else
-            InsertPos = std::next(VInst->getIterator());
-
-          while (isa<PHINode>(InsertPos))
-            ++InsertPos;
-          // This instruction may contain multiple uses of V, update them all.
-          CurUser->replaceUsesOfWith(
-              V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
-        } else {
-          CurUser->replaceUsesOfWith(
-              V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
-                                                V->getType()));
-        }
-      }
+      performPointerReplacement(V, NewV, U, ValueWithNewAddrSpace,
+                                DeadInstructions);
     }
 
     if (V->use_empty()) {



More information about the llvm-commits mailing list