[llvm] [DirectX] Simplify and correct the flattening of GEPs in DXILFlattenArrays (PR #146173)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 12:07:32 PDT 2025


================
@@ -225,131 +212,145 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
   return true;
 }
 
-void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
-    GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
-    SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
-
-  Type *CurrentType = GEP.getSourceElementType();
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  // Do not visit GEPs more than once
+  if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
+    return false;
 
-  // Note index 0 is the ptr index.
-  for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
-    Indices.push_back(Index);
-    AllIndicesAreConstInt &= isa<ConstantInt>(Index);
+  Value *PtrOperand = GEP.getPointerOperand();
 
-    if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
-      Dims.push_back(ArrayTy->getNumElements());
-      CurrentType = ArrayTy->getElementType();
-    } else {
-      assert(false && "Expected array type in GEP chain");
-    }
+  // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
+  // it can be visited
+  if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
+      PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEPI =
+        cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
+    OldGEPI->insertBefore(GEP.getIterator());
+
+    IRBuilder<> Builder(&GEP);
+    SmallVector<Value *> Indices(GEP.indices());
+    Value *NewGEP =
+        Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
+                          GEP.getName(), GEP.getNoWrapFlags());
+    assert(isa<GetElementPtrInst>(NewGEP) &&
+           "Expected newly-created GEP to be an instruction");
+    GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
+
+    GEP.replaceAllUsesWith(NewGEPI);
+    GEP.eraseFromParent();
+    visitGetElementPtrInst(*OldGEPI);
+    visitGetElementPtrInst(*NewGEPI);
+    return true;
   }
-}
 
-void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
-    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
-    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
-    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
-  // Check if this GEP is already in the map to avoid circular references
-  if (GEPChainMap.count(&CurrGEP) > 0)
-    return;
-
-  // Collect indices and dimensions from the current GEP
-  collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
-  bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
-  if (!IsMultiDimArr) {
-    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-    return;
-  }
-  bool GepUses = false;
-  for (auto *User : CurrGEP.users()) {
-    if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
-      recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
-                             ++GEPChainUseCount, Indices, Dims,
-                             AllIndicesAreConstInt);
-      GepUses = true;
+  // Construct GEPInfo for this GEP
+  GEPInfo Info;
+
+  // Obtain the variable and constant byte offsets computed by this GEP
+  const DataLayout &DL = GEP.getDataLayout();
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+  Info.ConstantOffset = {BitWidth, 0};
+  [[maybe_unused]] bool Success = GEP.collectOffset(
+      DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
+  assert(Success && "Failed to collect offsets for GEP");
+
+  // If there is a parent GEP, inherit the root array type and pointer, and
+  // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
+  // chain and we need to determine the root array type.
+  if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
+    assert(GEPChainInfoMap.contains(PtrOpGEP) &&
+           "Expected parent GEP to be visited before this GEP");
+    GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
+    Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
+    Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
+    for (auto &VariableOffset : PGEPInfo.VariableOffsets)
+      Info.VariableOffsets.insert(VariableOffset);
+    Info.ConstantOffset += PGEPInfo.ConstantOffset;
+  } else {
+    Info.RootPointerOperand = PtrOperand;
+
+    // We should try to determine the type of the root from the pointer rather
+    // than the GEP's source element type because this could be a scalar GEP
+    // into an array-typed pointer from an Alloca or Global Variable.
+    Type *RootTy = GEP.getSourceElementType();
+    if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
+      if (GlobalMap.contains(GlobalVar))
+        GlobalVar = GlobalMap[GlobalVar];
+      Info.RootPointerOperand = GlobalVar;
+      RootTy = GlobalVar->getValueType();
+    } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+      RootTy = Alloca->getAllocatedType();
     }
-  }
-  // This case is just incase the gep chain doesn't end with a 1d array.
-  if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-  }
-}
-
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
-    GetElementPtrInst &GEP) {
-  GEPData GEPInfo = GEPChainMap.at(&GEP);
-  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
-}
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
-    GEPData &GEPInfo, GetElementPtrInst &GEP) {
-  IRBuilder<> Builder(&GEP);
-  Value *FlatIndex;
-  if (GEPInfo.AllIndicesAreConstInt)
-    FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-  else
-    FlatIndex =
-        genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-
-  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
-
-  // Don't append '.flat' to an empty string. If the SSA name isn't available
-  // it could conflict with the ParentOperand's name.
-  std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
-
-  Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
-                                     {Builder.getInt32(0), FlatIndex}, FlatName,
-                                     GEP.getNoWrapFlags());
-
-  // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
-  // Erase the old GEP in the map before to avoid invalid instructions
-  // and circular references.
-  GEPChainMap.erase(&GEP);
-
-  GEP.replaceAllUsesWith(FlatGEP);
-  GEP.eraseFromParent();
-  return true;
-}
+    assert(!isMultiDimensionalArray(RootTy) &&
+           "Expected root array type to be flattened");
 
-bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
-  auto It = GEPChainMap.find(&GEP);
-  if (It != GEPChainMap.end())
-    return visitGetElementPtrInstInGEPChain(GEP);
-  if (!isMultiDimensionalArray(GEP.getSourceElementType()))
-    return false;
+    // If the root type is not an array, we don't need to do any flattening
+    if (!isa<ArrayType>(RootTy))
+      return false;
 
-  ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
-  IRBuilder<> Builder(&GEP);
-  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
-  ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
+    Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
+  }
 
-  Value *PtrOperand = GEP.getPointerOperand();
+  // GEPs without users or GEPs with non-GEP users should be replaced such that
+  // the chain of GEPs they are a part of are collapsed to a single GEP into a
+  // flattened array.
+  bool ReplaceThisGEP = GEP.users().empty();
+  for (Value *User : GEP.users())
+    if (!isa<GetElementPtrInst>(User))
+      ReplaceThisGEP = true;
+
+  if (ReplaceThisGEP) {
+    unsigned BytesPerElem =
+        DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
+    assert(isPowerOf2_32(BytesPerElem) &&
+           "Bytes per element should be a power of 2");
+
+    // Compute the 32-bit index for this flattened GEP from the constant and
+    // variable byte offsets in the GEPInfo
+    IRBuilder<> Builder(&GEP);
+    Value *ZeroIndex = Builder.getInt32(0);
+    uint64_t ConstantOffset =
+        Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
+    assert(ConstantOffset < UINT32_MAX &&
+           "Constant byte offset for flat GEP index must fit within 32 bits");
+    Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
+    for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
+      assert(Multiplier.getActiveBits() <= 32 &&
+             "The multiplier for a flat GEP index must fit within 32 bits");
+      assert(VarIndex->getType()->isIntegerTy(32) &&
+             "Expected i32-typed GEP indices");
+      Value *VI;
+      if (Multiplier.getZExtValue() % BytesPerElem != 0) {
+        // This can happen, e.g., with i8 GEPs. To handle this we just divide
+        // by BytesPerElem using an instruction after multiplying VarIndex by
+        // Multiplier.
+        VI = Builder.CreateMul(VarIndex,
+                               Builder.getInt32(Multiplier.getZExtValue()));
+        VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
+      } else
+        VI = Builder.CreateMul(
+            VarIndex,
+            Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
+      FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
+    }
 
-  unsigned GEPChainUseCount = 0;
-  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-
-  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
-  // Here recursion is used to get the length of the GEP chain.
-  // Handle zero uses here because there won't be an update via
-  // a child in the chain later.
-  if (GEPChainUseCount == 0) {
-    SmallVector<Value *> Indices;
-    SmallVector<uint64_t> Dims;
-    bool AllIndicesAreConstInt = true;
-
-    // Collect indices and dimensions from the GEP
-    collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
-    GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
-                    std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
-    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+    // Construct a new GEP for the flattened array to replace the current GEP
+    Value *NewGEP = Builder.CreateGEP(
+        Info.RootFlattenedArrayType, Info.RootPointerOperand,
+        {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
+
+    // Replace the current GEP with the new GEP. Store GEPInfo into the map
+    // for later use in case this GEP was not the end of the chain
+    GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
----------------
inbelic wrote:

nit: is there any merit to removing the old GEP from the map at this point? I see that we clear the map above anyway.

https://github.com/llvm/llvm-project/pull/146173


More information about the llvm-commits mailing list