[llvm-branch-commits] [llvm] ee7e6c4 - common chains

Chen Zheng via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Aug 17 22:15:15 PDT 2021


Author: Chen Zheng
Date: 2021-08-18T03:20:39Z
New Revision: ee7e6c4e05af743c1ba2db57abd33fb828d49025

URL: https://github.com/llvm/llvm-project/commit/ee7e6c4e05af743c1ba2db57abd33fb828d49025
DIFF: https://github.com/llvm/llvm-project/commit/ee7e6c4e05af743c1ba2db57abd33fb828d49025.diff

LOG: common chains

Added: 
    

Modified: 
    llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
index 010f49c8d3ebc..c3d3f1504fd4d 100644
--- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
+++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
@@ -140,6 +140,38 @@ namespace {
     SmallVector<BucketElement, 16> Elements;
   };
 
+  /// One memory access recorded in a ChainBucket: the accessing instruction
+  /// together with the SCEV offset of its address from the bucket base.
+  struct ChainBucketElement {
+    /// Element for the access that defines the bucket base. It has no offset,
+    /// so Offset is left null.
+    ChainBucketElement(Instruction *I) : ChainBucketElement(nullptr, I) {}
+
+    /// Element for an access whose address is \p O away from the bucket base.
+    ChainBucketElement(const SCEV *O, Instruction *I) : Offset(O), Instr(I) {}
+
+    const SCEV *Offset; ///< Null for the base element.
+    Instruction *Instr;
+  };
+
+  /// A set of memory accesses that addOneCandidateForChain() grouped under one
+  /// base address, plus the chain layout computed by prepareBasesForChains().
+  struct ChainBucket {
+    /// Start a bucket whose base address is \p B. The access \p I that
+    /// produced the base becomes the first element, with a null offset.
+    ChainBucket(const SCEV *B, Instruction *I)
+        : BaseSCEV(B), Elements(1, ChainBucketElement(I)), ChainSize(0) {}
+
+    /// Address SCEV of the first element.
+    const SCEV *BaseSCEV;
+    /// All accesses of the bucket, in insertion order.
+    SmallVector<ChainBucketElement, 16> Elements;
+    /// Number of elements per chain; 0 until prepareBasesForChains() sets it.
+    unsigned ChainSize;
+    /// First element of every chain; filled by prepareBasesForChains().
+    SmallVector<ChainBucketElement, 16> ChainBases;
+
+    /// Print the bucket through LLVM_DEBUG. Produces no output unless debug
+    /// output for this pass is enabled.
+    void dump() const {
+      LLVM_DEBUG(dbgs() << "Chain base scev is "; BaseSCEV->dump());
+      LLVM_DEBUG(dbgs() << "Chain element size is "<< Elements.size() << "\n");
+      // Iterate by reference; the elements do not need to be copied.
+      for (const ChainBucketElement &E : Elements) {
+        if (!E.Offset) {
+          LLVM_DEBUG(dbgs() << "base Element Instruction is ";
+                     E.Instr->dump());
+          continue;
+        }
+        LLVM_DEBUG(dbgs() << "Element offset is "; E.Offset->dump());
+        LLVM_DEBUG(dbgs() << "Element instruction is "; E.Instr->dump());
+      }
+    }
+  };
+
   // "UpdateForm" is not a real PPC instruction form, it stands for dform
   // load/store with update like ldu/stdu, or Prefetch intrinsic.
   // For DS form instructions, their displacements must be multiple of 4.
@@ -192,6 +224,21 @@ namespace {
     Value *getPreparedIncNode(Loop *L, Instruction *MemI,
                               const SCEV *BasePtrIncSCEV);
 
+    /// Gather the loads/stores of Loop \p L that are candidates for chain
+    /// commoning, grouped into buckets.
+    SmallVector<ChainBucket, 16> collectCandidatesForChain(Loop *L);
+
+    /// Insert candidate \p MemI, whose address is \p LSCEV, into a matching
+    /// bucket of \p Buckets, or open a new bucket for it.
+    void addOneCandidateForChain(Instruction *MemI, const SCEV *LSCEV,
+                                 SmallVector<ChainBucket, 16> &Buckets);
+
+    /// Common the chains in \p ChainBuckets so that offsets are reused inside
+    /// Loop \p L, which reduces register pressure.
+    bool chainCommoning(Loop *L, SmallVector<ChainBucket, 16> &ChainBuckets);
+
+    /// Compute the chain size and the chain bases of \p BucketChain. Returns
+    /// false if the bucket cannot be split into equally shaped chains.
+    bool prepareBasesForChains(ChainBucket &BucketChain);
+
+    /// Rewrite the accesses of \p Bucket in Loop \p L on top of per-chain
+    /// base pointers. Blocks that may hold dead PHIs afterwards are added to
+    /// \p BBChanged; \p ExpandedOffsets caches already expanded offsets.
+    bool rewriteLoadStoresForChains(
+        Loop *L, ChainBucket &Bucket, SmallSet<BasicBlock *, 16> &BBChanged,
+        DenseMap<const SCEV *, Value *> &ExpandedOffsets);
+
     /// Collect condition matched(\p isValidCandidate() returns true)
     /// candidates in Loop \p L.
     SmallVector<Bucket, 16> collectCandidates(
@@ -272,7 +319,7 @@ static std::string getInstrName(const Value *I, StringRef Suffix) {
     return "";
 }
 
-static Value *GetPointerOperand(Value *MemI) {
+static Value *getPtrOperand(Value *MemI) {
   if (LoadInst *LMemI = dyn_cast<LoadInst>(MemI)) {
     return LMemI->getPointerOperand();
   } else if (StoreInst *SMemI = dyn_cast<StoreInst>(MemI)) {
@@ -309,10 +356,448 @@ bool PPCLoopInstrFormPrep::runOnFunction(Function &F) {
   return MadeChange;
 }
 
+// Check whether \p LSCEV is an affine add recurrence whose start contains
+// exactly one pointer-typed term, so that the start can be used as a chain
+// separator. Two start shapes are accepted:
+//   - a single pointer-typed SCEVUnknown;
+//   - a SCEVAddExpr with exactly one pointer operand, all other operands
+//     being integers.
+// Returns false for anything else, including non-AddRec input.
+static bool isValidChainCandidate(const SCEV *LSCEV) {
+  // dyn_cast, not cast: cast<> asserts on a type mismatch and never yields
+  // null, which would make the null check below dead code.
+  const SCEVAddRecExpr *ARSCEV = dyn_cast<SCEVAddRecExpr>(LSCEV);
+  if (!ARSCEV)
+    return false;
+
+  if (!ARSCEV->isAffine())
+    return false;
+
+  const SCEV *Start = ARSCEV->getStart();
+  LLVM_DEBUG(dbgs() << "Start SCEV is "; Start->dump());
+  LLVM_DEBUG(dbgs() << "Start SCEV type is "; Start->getType()->dump());
+
+  LLVM_DEBUG(dbgs() << "start is unknown is " << isa<SCEVUnknown>(Start)
+                    << "\n");
+
+  // A single pointer.
+  if (isa<SCEVUnknown>(Start) && Start->getType()->isPointerTy())
+    return true;
+
+  // Now we only handle SCEVAddExpr.
+  const SCEVAddExpr *ASCEV = dyn_cast<SCEVAddExpr>(Start);
+  if (!ASCEV)
+    return false;
+
+  bool SawPointer = false;
+  LLVM_DEBUG(dbgs() << "operand number is " << ASCEV->getNumOperands()
+                    << "\n");
+  // 1-based operand index, used only by the debug output.
+  unsigned OpIdx = 0;
+  for (const SCEV *Op : ASCEV->operands()) {
+    ++OpIdx;
+    LLVM_DEBUG(dbgs() << "operand " << OpIdx << " is "; Op->dump());
+    LLVM_DEBUG(dbgs() << "operand " << OpIdx << " type is ";
+               Op->getType()->dump());
+    if (Op->getType()->isPointerTy()) {
+      // A second pointer operand makes the start ambiguous as a separator.
+      if (SawPointer)
+        return false;
+      SawPointer = true;
+    } else if (!Op->getType()->isIntegerTy())
+      return false;
+  }
+  (void)OpIdx; // Only read inside LLVM_DEBUG.
+
+  return SawPointer;
+}
+
+// Make sure the diff between the base and new candidate is:
+// 1: an integer type.
+// 2: does not contain any pointer type.
+// A constant diff is rejected as well. Returns true only for an integer-typed
+// SCEVUnknown or an n-ary SCEV expression whose operands are all integers.
+static bool isValidChainDiff(const SCEV *LSCEV) {
+  assert(LSCEV && "Invalid SCEV for Ptr value.");
+  LLVM_DEBUG(dbgs() << "enter chain diff funciton\n");
+  LLVM_DEBUG(dbgs() << "diff is "; LSCEV->dump());
+
+  // Don't mess up previous dform prepare.
+  if (isa<SCEVConstant>(LSCEV))
+    return false;
+  LLVM_DEBUG(dbgs() << "get SCEV type is " << LSCEV->getSCEVType() << "\n");
+
+  // A single integer type offset.
+  if (isa<SCEVUnknown>(LSCEV) && LSCEV->getType()->isIntegerTy())
+    return true;
+
+  const SCEVNAryExpr *ASCEV = dyn_cast<SCEVNAryExpr>(LSCEV);
+  if (!ASCEV)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "diff is "; LSCEV->dump());
+  LLVM_DEBUG(dbgs() << "diff type is "; LSCEV->getType()->dump());
+
+  LLVM_DEBUG(dbgs() << "operand number is " << ASCEV->getNumOperands()
+                    << "\n");
+  // 1-based operand index for the debug output. The counter was previously
+  // never advanced, so every operand was reported as "operand 0".
+  unsigned OpIdx = 0;
+  for (const SCEV *Op : ASCEV->operands()) {
+    ++OpIdx;
+    LLVM_DEBUG(dbgs() << "operand " << OpIdx << " is "; Op->dump());
+    LLVM_DEBUG(dbgs() << "operand " << OpIdx << " type is ";
+               Op->getType()->dump());
+    if (!Op->getType()->isIntegerTy())
+      return false;
+  }
+  (void)OpIdx; // Only read inside LLVM_DEBUG.
+
+  return true;
+}
+
+/// Try to file the access \p MemI (address \p LSCEV) into \p Buckets.
+/// Candidates rejected by isValidChainCandidate() are dropped. Otherwise the
+/// access joins the first bucket that has the same step recurrence and whose
+/// base differs from \p LSCEV by a valid chain diff; if there is none, a new
+/// bucket based on \p LSCEV is appended.
+void PPCLoopInstrFormPrep::addOneCandidateForChain(
+    Instruction *MemI, const SCEV *LSCEV,
+    SmallVector<ChainBucket, 16> &Buckets) {
+  assert((MemI && getPtrOperand(MemI)) &&
+         "Candidate should be a memory instruction.");
+  assert(LSCEV && "Invalid SCEV for Ptr value.");
+
+  if (!isValidChainCandidate(LSCEV)) {
+    LLVM_DEBUG(dbgs() << "invalid chain candidate\n");
+    return;
+  }
+
+  LLVM_DEBUG(dbgs() << "valid chain candidate\n");
+
+  // The candidate's step is the same for every bucket it is compared with.
+  const SCEV *CandidateStep =
+      cast<SCEVAddRecExpr>(LSCEV)->getStepRecurrence(*SE);
+
+  for (ChainBucket &Existing : Buckets) {
+    const SCEV *BucketStep =
+        cast<SCEVAddRecExpr>(Existing.BaseSCEV)->getStepRecurrence(*SE);
+    if (BucketStep != CandidateStep)
+      continue;
+
+    const SCEV *Diff = SE->getMinusSCEV(LSCEV, Existing.BaseSCEV);
+    if (!isValidChainDiff(Diff))
+      continue;
+
+    LLVM_DEBUG(dbgs() << "add a valid candidate for one chain\n");
+    Existing.Elements.push_back(ChainBucketElement(Diff, MemI));
+    return;
+  }
+
+  // No existing bucket accepted the candidate.
+  LLVM_DEBUG(dbgs() << "create new chain\n");
+  Buckets.push_back(ChainBucket(LSCEV, MemI));
+}
+
+/// Walk every instruction of Loop \p L and bucket the memory accesses that
+/// may take part in chain commoning. An access is considered only if its
+/// pointer is in address space 0, is not loop invariant, and has an add
+/// recurrence SCEV that belongs to \p L.
+SmallVector<ChainBucket, 16>
+PPCLoopInstrFormPrep::collectCandidatesForChain(Loop *L) {
+  SmallVector<ChainBucket, 16> Result;
+
+  for (BasicBlock *Block : L->blocks()) {
+    for (Instruction &Inst : *Block) {
+      // Skip everything that has no pointer operand.
+      Value *Ptr = getPtrOperand(&Inst);
+      if (!Ptr)
+        continue;
+
+      // Only the default address space is handled.
+      if (Ptr->getType()->getPointerAddressSpace() != 0)
+        continue;
+
+      if (L->isLoopInvariant(Ptr))
+        continue;
+
+      const SCEV *PtrSCEV = SE->getSCEVAtScope(Ptr, L);
+      const auto *AddRec = dyn_cast<SCEVAddRecExpr>(PtrSCEV);
+      LLVM_DEBUG(dbgs() << "Instructions is "; Inst.dump());
+      LLVM_DEBUG(dbgs() << "ptr value is "; Ptr->dump());
+      LLVM_DEBUG(dbgs() << "SCEV is "; PtrSCEV->dump());
+
+      // The address has to be a recurrence of this very loop.
+      if (AddRec && AddRec->getLoop() == L)
+        addOneCandidateForChain(&Inst, PtrSCEV, Result);
+    }
+  }
+
+  return Result;
+}
+
+/// Work out how the elements of \p CBucket split into chains of equal size
+/// and record the result in CBucket.ChainSize and CBucket.ChainBases.
+///
+/// The offset of the second element is taken as the stride. Neighbouring
+/// elements that are exactly one stride apart are counted; a neighbour pair
+/// with a different distance is treated as a separator between chains.
+/// Returns false when no stride is reused, when the elements do not divide
+/// evenly into chains, or when the chains do not repeat the offsets of the
+/// first chain.
+bool PPCLoopInstrFormPrep::prepareBasesForChains(ChainBucket &CBucket) {
+  assert(CBucket.Elements.size() >= 4 && "Not a candidate chain!\n");
+
+  auto &Elts = CBucket.Elements;
+  const unsigned NumElts = Elts.size();
+  const SCEV *Stride = Elts[1].Offset;
+
+  // Both counters start at 1 to account for the element at index 1.
+  unsigned StrideHits = 1;
+  unsigned LeadingHits = 1;
+  bool HasSeparator = false;
+  for (unsigned Cur = 2; Cur != NumElts; ++Cur) {
+    const SCEV *Gap = SE->getMinusSCEV(Elts[Cur].Offset, Elts[Cur - 1].Offset);
+    if (Gap != Stride) {
+      HasSeparator = true;
+      continue;
+    }
+    if (!HasSeparator)
+      ++LeadingHits;
+    ++StrideHits;
+  }
+
+  LLVM_DEBUG(dbgs() << "total count is " << StrideHits << " firstgroup count is " << LeadingHits << " saw separater is " << HasSeparator << "\n");
+  // No reuseable offset, skip now.
+  if (StrideHits == 1)
+    return false;
+
+  // Without a separator all elements advance by the stride, and the number of
+  // chains is chosen as sqrt(NumElts). Otherwise it follows from the counts.
+  const unsigned NumChains = HasSeparator
+                                 ? StrideHits / LeadingHits
+                                 : static_cast<unsigned>(sqrt(NumElts));
+  CBucket.ChainSize = NumElts / NumChains;
+
+  // Only perfect chains (every element belongs to a full chain) are handled.
+  if (CBucket.ChainSize * NumChains != NumElts) {
+    if (HasSeparator)
+      LLVM_DEBUG(dbgs() << "return false 1\n");
+    return false;
+  }
+
+  if (HasSeparator) {
+    // Every later chain must repeat the offsets of the first chain relative
+    // to its own first element.
+    for (unsigned Pos = 1; Pos < CBucket.ChainSize; Pos++)
+      for (unsigned Chain = 1; Chain < NumChains; Chain++) {
+        const unsigned ChainStart = Chain * CBucket.ChainSize;
+        const SCEV *Relative = SE->getMinusSCEV(Elts[Pos + ChainStart].Offset,
+                                                Elts[ChainStart].Offset);
+        if (Elts[Pos].Offset != Relative) {
+          LLVM_DEBUG(dbgs() << "return false 2\n");
+          return false;
+        }
+      }
+  }
+
+  // The first element of each chain becomes that chain's base.
+  for (unsigned Chain = 0; Chain < NumChains; Chain++)
+    CBucket.ChainBases.push_back(Elts[Chain * CBucket.ChainSize]);
+
+  return true;
+}
+
+/// Run chain commoning on every bucket of \p ChainBuckets for Loop \p L.
+/// Buckets with fewer than four elements are ignored. Returns true if any
+/// bucket was rewritten; in that case dead PHIs are removed from the loop
+/// blocks recorded as changed.
+bool PPCLoopInstrFormPrep::chainCommoning(
+    Loop *L, SmallVector<ChainBucket, 16> &ChainBuckets) {
+  if (ChainBuckets.empty())
+    return false;
+
+  SmallSet<BasicBlock *, 16> BBChanged;
+  // Shared by all buckets so that an offset is expanded only once.
+  DenseMap<const SCEV *, Value *> ExpandedOffsets;
+  bool Changed = false;
+
+  for (ChainBucket &CB : ChainBuckets) {
+    // The minimal size for profitable chain commoning:
+    // A1 = base + offset1
+    // A2 = base + offset2 (offset2 - offset1 = X)
+    // A3 = base + offset3
+    // A4 = base + offset4 (offset4 - offset3 = X)
+    // ======>
+    // base1 = base + offset1
+    // base2 = base + offset3
+    // A1 = base1
+    // A2 = base1 + X
+    // A3 = base2
+    // A4 = base2 + X
+    //
+    // The benefit comes from reusing offset 'X'.
+    if (CB.Elements.size() < 4)
+      continue;
+
+    LLVM_DEBUG(dbgs() << "start to prepare bases\n");
+    if (!prepareBasesForChains(CB))
+      continue;
+
+    LLVM_DEBUG(dbgs() << "is a valid bucket chain\n");
+    LLVM_DEBUG(dbgs() << "group size is " << CB.ChainSize << "\n");
+    Changed |= rewriteLoadStoresForChains(L, CB, BBChanged, ExpandedOffsets);
+  }
+
+  if (!Changed)
+    return false;
+
+  // The rewrite may have left PHIs without users behind.
+  for (BasicBlock *Block : L->blocks())
+    if (BBChanged.count(Block))
+      DeleteDeadPHIs(Block);
+
+  return true;
+}
+
+/// Rewrite the accesses of \p Bucket in Loop \p L on top of per-chain bases.
+///
+/// For every chain a new i8* PHI is created in the loop header. It starts at
+/// the chain base (expanded in the loop predecessor) and is advanced by the
+/// step of the recurrence in every other predecessor of the header. The base
+/// access of the chain uses the PHI directly; each remaining access uses a
+/// GEP of the PHI by an offset value taken from \p ExpandedOffsets. Offsets
+/// are expanded while processing the first chain and reused by later chains.
+///
+/// \param BBChanged        Receives the parent blocks of replaced pointers.
+/// \param ExpandedOffsets  Cache from offset SCEV to its expanded value.
+/// \returns true if the IR was modified. This includes the case where a
+///          later chain could not be rewritten after earlier ones were.
+bool PPCLoopInstrFormPrep::rewriteLoadStoresForChains(
+    Loop *L, ChainBucket &Bucket, SmallSet<BasicBlock *, 16> &BBChanged,
+    DenseMap<const SCEV *, Value *> &ExpandedOffsets) {
+  bool MadeChange = false;
+
+  assert(Bucket.Elements.size() == Bucket.ChainBases.size() * Bucket.ChainSize && "invalid bucket for chain commoning!\n");
+
+  // Make sure each offset is able to expand.
+  for (unsigned Idx = 1; Idx < Bucket.ChainSize; ++Idx)
+    if (!isSafeToExpand(Bucket.Elements[Idx].Offset, *SE))
+      return MadeChange;
+
+  // Make sure each base is able to expand.
+  for (unsigned Idx = 0; Idx < Bucket.ChainBases.size(); ++Idx) {
+    const SCEV *BaseSCEV =
+        Idx ? SE->getAddExpr(Bucket.BaseSCEV, Bucket.ChainBases[Idx].Offset)
+            : Bucket.BaseSCEV;
+    const SCEVAddRecExpr *BasePtrSCEV = cast<SCEVAddRecExpr>(BaseSCEV);
+    if (!isSafeToExpand(BasePtrSCEV->getStart(), *SE))
+      return MadeChange;
+  }
+
+  for (unsigned ChainIdx = 0; ChainIdx < Bucket.ChainBases.size();
+       ++ChainIdx) {
+    unsigned BaseElemIdx = Bucket.ChainSize * ChainIdx;
+    const SCEV *BaseSCEV =
+        ChainIdx ? SE->getAddExpr(Bucket.BaseSCEV,
+                                  Bucket.Elements[BaseElemIdx].Offset)
+                 : Bucket.BaseSCEV;
+    const SCEVAddRecExpr *BasePtrSCEV = cast<SCEVAddRecExpr>(BaseSCEV);
+    assert(BasePtrSCEV->isAffine() && "Invalid SCEV type for the base ptr for a candidate chain!\n");
+    assert(BasePtrSCEV->getLoop() == L && "AddRec for the wrong loop?");
+
+    // The first element of a chain is always the base of that chain.
+    Instruction *MemI = Bucket.Elements[BaseElemIdx].Instr;
+    Value *BasePtr = getPtrOperand(MemI);
+    assert(BasePtr && "No pointer operand");
+
+    Type *I8Ty = Type::getInt8Ty(MemI->getParent()->getContext());
+    Type *I64Ty = Type::getInt64Ty(MemI->getParent()->getContext());
+    Type *I8PtrTy =
+        Type::getInt8PtrTy(MemI->getParent()->getContext(),
+                           BasePtr->getType()->getPointerAddressSpace());
+
+    const SCEV *BasePtrIncSCEV = BasePtrSCEV->getStepRecurrence(*SE);
+    const SCEV *BasePtrStartSCEV = BasePtrSCEV->getStart();
+    assert((SE->isLoopInvariant(BasePtrStartSCEV, L) && SE->isLoopInvariant(BasePtrIncSCEV, L)) && "Invalid base SCEV!\n");
+
+    // No valid representation for the increment. MadeChange is already true
+    // here if an earlier chain of this bucket has been rewritten.
+    Value *IncNode = getPreparedIncNode(L, MemI, BasePtrIncSCEV);
+    if (!IncNode) {
+      LLVM_DEBUG(dbgs() << "Loop Increasement can not be represented!\n");
+      return MadeChange;
+    }
+
+    BasicBlock *Header = L->getHeader();
+    unsigned HeaderLoopPredCount = pred_size(Header);
+    BasicBlock *LoopPredecessor = L->getLoopPredecessor();
+
+    PHINode *NewPHI =
+        PHINode::Create(I8PtrTy, HeaderLoopPredCount,
+                        getInstrName(MemI, PHINodeNameSuffix),
+                        Header->getFirstNonPHI());
+
+    // The IR is modified from here on, so report a change even if a later
+    // chain bails out.
+    MadeChange = true;
+
+    SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "pistart");
+    Value *BasePtrStart = SCEVE.expandCodeFor(
+        BasePtrStartSCEV, I8PtrTy, LoopPredecessor->getTerminator());
+
+    // Note that LoopPredecessor might occur in the predecessor list multiple
+    // times, and we need to add it the right number of times.
+    for (auto PI : predecessors(Header)) {
+      if (PI != LoopPredecessor)
+        continue;
+
+      NewPHI->addIncoming(BasePtrStart, LoopPredecessor);
+    }
+
+    // Note that LoopPredecessor might occur in the predecessor list multiple
+    // times, and we need to make sure no more incoming value for them in PHI.
+    for (auto PI : predecessors(Header)) {
+      if (PI == LoopPredecessor)
+        continue;
+
+      // For the latch predecessor, we need to insert a GEP just before the
+      // terminator to increase the address.
+      Instruction *InsPoint = PI->getTerminator();
+      GetElementPtrInst *LatchInc = GetElementPtrInst::Create(
+          I8Ty, NewPHI, IncNode, getInstrName(MemI, GEPNodeIncNameSuffix),
+          InsPoint);
+      LatchInc->setIsInBounds(IsPtrInBounds(BasePtr));
+
+      NewPHI->addIncoming(LatchInc, PI);
+    }
+
+    // The chain base is the PHI itself, cast back to the original pointer
+    // type when that differs from i8*.
+    Instruction *NewBasePtr = NewPHI;
+    if (NewPHI->getType() != BasePtr->getType())
+      NewBasePtr = new BitCastInst(NewPHI, BasePtr->getType(),
+                                   getInstrName(NewPHI, CastNodeNameSuffix),
+                                   &*Header->getFirstInsertionPt());
+
+    // Expand the offsets now, before we clear the SCEV expander. Only the
+    // first chain needs to do this; later chains look the values up.
+    if (ChainIdx == 0) {
+      for (unsigned Idx = 1; Idx < Bucket.ChainSize; ++Idx) {
+        const SCEV *OffsetSCEV = Bucket.Elements[Idx].Offset;
+        if (ExpandedOffsets.find(OffsetSCEV) == ExpandedOffsets.end()) {
+          Value *Offset = SCEVE.expandCodeFor(
+              OffsetSCEV, I64Ty, LoopPredecessor->getTerminator());
+          ExpandedOffsets[OffsetSCEV] = Offset;
+        }
+      }
+    }
+
+    LLVM_DEBUG(dbgs() << "debug mark\n");
+
+    // Clear the rewriter cache, because values that are in the rewriter's
+    // cache can be deleted below, causing the AssertingVH in the cache to
+    // trigger.
+    SCEVE.clear();
+
+    if (Instruction *IDel = dyn_cast<Instruction>(BasePtr))
+      BBChanged.insert(IDel->getParent());
+    BasePtr->replaceAllUsesWith(NewBasePtr);
+    RecursivelyDeleteTriviallyDeadInstructions(BasePtr);
+
+    // Keep track of the replacement pointer values we've inserted so that we
+    // don't generate more pointer values than necessary.
+    SmallPtrSet<Value *, 16> NewPtrs;
+    NewPtrs.insert(NewBasePtr);
+
+    for (unsigned Idx = BaseElemIdx + 1;
+         Idx < BaseElemIdx + Bucket.ChainSize; ++Idx) {
+      ChainBucketElement &Elem = Bucket.Elements[Idx];
+      Value *Ptr = getPtrOperand(Elem.Instr);
+      assert(Ptr && "No pointer operand");
+      if (NewPtrs.count(Ptr))
+        continue;
+
+      Instruction *RealNewPtr;
+      if (!Elem.Offset) {
+        RealNewPtr = NewBasePtr;
+      } else {
+        // Pick the insertion point for the new GEP. A null PtrIP means
+        // "directly after the new PHI".
+        Instruction *PtrIP = dyn_cast<Instruction>(Ptr);
+        if (PtrIP && NewBasePtr->getParent() == PtrIP->getParent())
+          PtrIP = nullptr;
+        else if (PtrIP && isa<PHINode>(PtrIP))
+          PtrIP = &*PtrIP->getParent()->getFirstInsertionPt();
+        else if (!PtrIP)
+          PtrIP = Elem.Instr;
+
+        // Offset relative to this chain's base. For the first chain that is
+        // the element offset itself.
+        const SCEV *Offset =
+            BaseElemIdx
+                ? SE->getMinusSCEV(Elem.Offset,
+                                   Bucket.Elements[BaseElemIdx].Offset)
+                : Elem.Offset;
+        auto OffsetIt = ExpandedOffsets.find(Offset);
+        assert(OffsetIt != ExpandedOffsets.end() && "Offset should be expanded before!\n");
+
+        GetElementPtrInst *NewPtr = GetElementPtrInst::Create(
+            I8Ty, NewPHI, OffsetIt->second,
+            getInstrName(Elem.Instr, GEPNodeOffNameSuffix), PtrIP);
+        if (!PtrIP)
+          NewPtr->insertAfter(NewPHI);
+        NewPtr->setIsInBounds(IsPtrInBounds(Ptr));
+        RealNewPtr = NewPtr;
+      }
+
+      if (Instruction *IDel = dyn_cast<Instruction>(Ptr))
+        BBChanged.insert(IDel->getParent());
+
+      Instruction *ReplNewPtr;
+      if (Ptr->getType() != RealNewPtr->getType()) {
+        ReplNewPtr = new BitCastInst(RealNewPtr, Ptr->getType(),
+                                     getInstrName(Ptr, CastNodeNameSuffix));
+        ReplNewPtr->insertAfter(RealNewPtr);
+      } else
+        ReplNewPtr = RealNewPtr;
+
+      Ptr->replaceAllUsesWith(ReplNewPtr);
+      RecursivelyDeleteTriviallyDeadInstructions(Ptr);
+
+      NewPtrs.insert(RealNewPtr);
+    }
+  }
+
+  return MadeChange;
+}
+
+
 void PPCLoopInstrFormPrep::addOneCandidate(Instruction *MemI, const SCEV *LSCEV,
                                         SmallVector<Bucket, 16> &Buckets,
                                         unsigned MaxCandidateNum) {
-  assert((MemI && GetPointerOperand(MemI)) &&
+  assert((MemI && getPtrOperand(MemI)) &&
          "Candidate should be a memory instruction.");
   assert(LSCEV && "Invalid SCEV for Ptr value.");
   bool FoundBucket = false;
@@ -340,6 +825,35 @@ SmallVector<Bucket, 16> PPCLoopInstrFormPrep::collectCandidates(
   SmallVector<Bucket, 16> Buckets;
   for (const auto &BB : L->blocks())
     for (auto &J : *BB) {
+      Value *PtrValue = getPtrOperand(&J);
+      if (!PtrValue)
+        continue;
+
+      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
+      if (PtrAddrSpace)
+        continue;
+
+      if (L->isLoopInvariant(PtrValue))
+        continue;
+
+      const SCEV *LSCEV = SE->getSCEVAtScope(PtrValue, L);
+      const SCEVAddRecExpr *LARSCEV = dyn_cast<SCEVAddRecExpr>(LSCEV);
+      if (!LARSCEV || LARSCEV->getLoop() != L)
+        continue;
+
+      Type *PointerElementType;
+      if (LoadInst *LMemI = dyn_cast<LoadInst>(&J))
+        PointerElementType = LMemI->getType();
+      else if (StoreInst *SMemI = dyn_cast<StoreInst>(&J))
+        PointerElementType = SMemI->getValueOperand()->getType();
+      else {
+        assert(isa<IntrinsicInst>(&J) && "Invalid point operand!\n");
+        PointerElementType = Type::getInt8Ty(J.getContext());
+      }
+
+      if (isValidCandidate(&J, PtrValue, PointerElementType))
+        addOneCandidate(&J, LSCEV, Buckets, MaxCandidateNum);
+/*
       Value *PtrValue;
       Type *PointerElementType;
 
@@ -368,11 +882,15 @@ SmallVector<Bucket, 16> PPCLoopInstrFormPrep::collectCandidates(
 
       const SCEV *LSCEV = SE->getSCEVAtScope(PtrValue, L);
       const SCEVAddRecExpr *LARSCEV = dyn_cast<SCEVAddRecExpr>(LSCEV);
+      LLVM_DEBUG(dbgs() << "Instructions is "; J.dump());
+      LLVM_DEBUG(dbgs() << "ptr value is "; PtrValue->dump());
+      LLVM_DEBUG(dbgs() << "SCEV is "; LSCEV->dump());
       if (!LARSCEV || LARSCEV->getLoop() != L)
         continue;
 
       if (isValidCandidate(&J, PtrValue, PointerElementType))
         addOneCandidate(&J, LSCEV, Buckets, MaxCandidateNum);
+*/
     }
   return Buckets;
 }
@@ -508,7 +1026,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain,
   // The instruction corresponding to the Bucket's BaseSCEV must be the first
   // in the vector of elements.
   Instruction *MemI = BucketChain.Elements.begin()->Instr;
-  Value *BasePtr = GetPointerOperand(MemI);
+  Value *BasePtr = getPtrOperand(MemI);
   assert(BasePtr && "No pointer operand");
 
   Type *I8Ty = Type::getInt8Ty(MemI->getParent()->getContext());
@@ -645,7 +1163,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain,
 
   for (auto I = std::next(BucketChain.Elements.begin()),
        IE = BucketChain.Elements.end(); I != IE; ++I) {
-    Value *Ptr = GetPointerOperand(I->Instr);
+    Value *Ptr = getPtrOperand(I->Instr);
     assert(Ptr && "No pointer operand");
     if (NewPtrs.count(Ptr))
       continue;
@@ -985,7 +1503,8 @@ bool PPCLoopInstrFormPrep::runOnLoop(Loop *L) {
     return ST && ST->hasP9Vector() && (PointerElementType->isVectorTy());
   };
 
-  // intrinsic for update form.
+  // Collect buckets of comparable addresses used by loads and stores for update
+  // form.
   SmallVector<Bucket, 16> UpdateFormBuckets =
       collectCandidates(L, isUpdateFormCandidate, MaxVarsUpdateForm);
 
@@ -1011,5 +1530,21 @@ bool PPCLoopInstrFormPrep::runOnLoop(Loop *L) {
   if (!DQFormBuckets.empty())
     MadeChange |= dispFormPrep(L, DQFormBuckets, DQForm);
 
+  LLVM_DEBUG(dbgs() << "start to collect chain candidates\n");
+  // Collect buckets of comparable addresses used by loads and stores for chain
+  // commoning. With chain commoning, we reuses offsets between the offsets, so
+  // the register pressure will be reduced.
+  SmallVector<ChainBucket, 16> ChainBuckets =
+      collectCandidatesForChain(L);
+
+  LLVM_DEBUG(dbgs() << "chain info is\n");
+  LLVM_DEBUG(dbgs() << "chain size is " << ChainBuckets.size() << "\n");
+  for (auto Chain : ChainBuckets)
+    Chain.dump();
+
+  // Prepare for chain commoning.
+  if (!ChainBuckets.empty())
+    MadeChange |= chainCommoning(L, ChainBuckets);
+
   return MadeChange;
 }


        


More information about the llvm-branch-commits mailing list