[llvm] c262ba2 - [Scalarizer] Avoid pointer element type accesses

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 3 01:29:06 PST 2022


Author: Nikita Popov
Date: 2022-03-03T10:28:58+01:00
New Revision: c262ba2aab78ec9857342f583101f2fe4edee372

URL: https://github.com/llvm/llvm-project/commit/c262ba2aab78ec9857342f583101f2fe4edee372
DIFF: https://github.com/llvm/llvm-project/commit/c262ba2aab78ec9857342f583101f2fe4edee372.diff

LOG: [Scalarizer] Avoid pointer element type accesses

Pass through the load/store type to the Scatterer instead.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/Scalarizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index f150eca4cec56..4a77ce826ca3a 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -94,7 +94,7 @@ class Scatterer {
   // Scatter V into Size components.  If new instructions are needed,
   // insert them before BBI in BB.  If Cache is nonnull, use it to cache
   // the results.
-  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy,
             ValueVector *cachePtr = nullptr);
 
   // Return component I, creating a new Value for it if necessary.
@@ -107,8 +107,8 @@ class Scatterer {
   BasicBlock *BB;
   BasicBlock::iterator BBI;
   Value *V;
+  Type *PtrElemTy;
   ValueVector *CachePtr;
-  PointerType *PtrTy;
   ValueVector Tmp;
   unsigned Size;
 };
@@ -214,7 +214,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
   bool visitCallInst(CallInst &ICI);
 
 private:
-  Scatterer scatter(Instruction *Point, Value *V);
+  Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr);
   void gather(Instruction *Op, const ValueVector &CV);
   bool canTransferMetadata(unsigned Kind);
   void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
@@ -263,12 +263,14 @@ INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                     "Scalarize vector operations", false, false)
 
 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
-                     ValueVector *cachePtr)
-  : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
+                     Type *PtrElemTy, ValueVector *cachePtr)
+    : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) {
   Type *Ty = V->getType();
-  PtrTy = dyn_cast<PointerType>(Ty);
-  if (PtrTy)
-    Ty = PtrTy->getPointerElementType();
+  if (Ty->isPointerTy()) {
+    assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) &&
+           "Pointer element type mismatch");
+    Ty = PtrElemTy;
+  }
   Size = cast<FixedVectorType>(Ty)->getNumElements();
   if (!CachePtr)
     Tmp.resize(Size, nullptr);
@@ -285,15 +287,15 @@ Value *Scatterer::operator[](unsigned I) {
   if (CV[I])
     return CV[I];
   IRBuilder<> Builder(BB, BBI);
-  if (PtrTy) {
-    Type *ElTy =
-        cast<VectorType>(PtrTy->getPointerElementType())->getElementType();
+  if (PtrElemTy) {
+    Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType();
     if (!CV[0]) {
-      Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace());
+      Type *NewPtrTy = PointerType::get(
+          VectorElemTy, V->getType()->getPointerAddressSpace());
       CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
     }
     if (I != 0)
-      CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I,
+      CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I,
                                          V->getName() + ".i" + Twine(I));
   } else {
     // Search through a chain of InsertElementInsts looking for element I.
@@ -360,13 +362,14 @@ bool ScalarizerVisitor::visit(Function &F) {
 
 // Return a scattered form of V that can be accessed by Point.  V must be a
 // vector or a pointer to a vector.
-Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
+Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
+                                     Type *PtrElemTy) {
   if (Argument *VArg = dyn_cast<Argument>(V)) {
     // Put the scattered form of arguments in the entry block,
     // so that it can be used everywhere.
     Function *F = VArg->getParent();
     BasicBlock *BB = &F->getEntryBlock();
-    return Scatterer(BB, BB->begin(), V, &Scattered[V]);
+    return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[V]);
   }
   if (Instruction *VOp = dyn_cast<Instruction>(V)) {
     // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
@@ -377,17 +380,17 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
     // need to analyse them further.
     if (!DT->isReachableFromEntry(VOp->getParent()))
       return Scatterer(Point->getParent(), Point->getIterator(),
-                       UndefValue::get(V->getType()));
+                       UndefValue::get(V->getType()), PtrElemTy);
     // Put the scattered form of an instruction directly after the
     // instruction, skipping over PHI nodes and debug intrinsics.
     BasicBlock *BB = VOp->getParent();
     return Scatterer(
         BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V,
-        &Scattered[V]);
+        PtrElemTy, &Scattered[V]);
   }
   // In the fallback case, just put the scattered before Point and
   // keep the result local to Point.
-  return Scatterer(Point->getParent(), Point->getIterator(), V);
+  return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy);
 }
 
 // Replace Op with the gathered form of the components in CV.  Defer the
@@ -889,7 +892,7 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
 
   unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
   IRBuilder<> Builder(&LI);
-  Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
+  Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType());
   ValueVector Res;
   Res.resize(NumElems);
 
@@ -915,7 +918,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
 
   unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
   IRBuilder<> Builder(&SI);
-  Scatterer VPtr = scatter(&SI, SI.getPointerOperand());
+  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType());
   Scatterer VVal = scatter(&SI, FullValue);
 
   ValueVector Stores;


        


More information about the llvm-commits mailing list