[llvm] [DirectX] Simplify and correct the flattening of GEPs in DXILFlattenArrays (PR #146173)

Deric C. via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 14:52:36 PDT 2025


https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/146173

>From b08bb11fd23d314360d8407fef2bc34b77d58e7c Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Thu, 26 Jun 2025 23:16:56 +0000
Subject: [PATCH 1/6] Simplify flattening of GEP chains

This simplification also fixes instances of incorrect flat index
computations.
---
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 271 +++++++++---------
 1 file changed, 132 insertions(+), 139 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 0b7cf2f970172..913a8dcb917f4 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -40,18 +40,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
   static char ID; // Pass identification.
 };
 
-struct GEPData {
-  ArrayType *ParentArrayType;
-  Value *ParentOperand;
-  SmallVector<Value *> Indices;
-  SmallVector<uint64_t> Dims;
-  bool AllIndicesAreConstInt;
+struct GEPInfo {
+  ArrayType *RootFlattenedArrayType;
+  Value *RootPointerOperand;
+  SmallMapVector<Value *, APInt, 4> VariableOffsets;
+  APInt ConstantOffset;
 };
 
 class DXILFlattenArraysVisitor
     : public InstVisitor<DXILFlattenArraysVisitor, bool> {
 public:
-  DXILFlattenArraysVisitor() {}
+  DXILFlattenArraysVisitor(
+      DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
+      : GlobalMap(GlobalMap) {}
   bool visit(Function &F);
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
@@ -78,7 +79,8 @@ class DXILFlattenArraysVisitor
 
 private:
   SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
-  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+  DenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
+  DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
   bool finish();
   ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
                                       ArrayRef<uint64_t> Dims,
@@ -86,23 +88,6 @@ class DXILFlattenArraysVisitor
   Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
                                       ArrayRef<uint64_t> Dims,
                                       IRBuilder<> &Builder);
-
-  // Helper function to collect indices and dimensions from a GEP instruction
-  void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
-                                    SmallVectorImpl<Value *> &Indices,
-                                    SmallVectorImpl<uint64_t> &Dims,
-                                    bool &AllIndicesAreConstInt);
-
-  void
-  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
-                         ArrayType *FlattenedArrayType, Value *PtrOperand,
-                         unsigned &GEPChainUseCount,
-                         SmallVector<Value *> Indices = SmallVector<Value *>(),
-                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
-                         bool AllIndicesAreConstInt = true);
-  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
-  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
-                                            GetElementPtrInst &GEP);
 };
 } // namespace
 
@@ -225,131 +210,139 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
   return true;
 }
 
-void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
-    GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
-    SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
-
-  Type *CurrentType = GEP.getSourceElementType();
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  // Do not visit GEPs more than once
+  if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
+    return false;
 
-  // Note index 0 is the ptr index.
-  for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
-    Indices.push_back(Index);
-    AllIndicesAreConstInt &= isa<ConstantInt>(Index);
+  // Construct GEPInfo for this GEP
+  GEPInfo Info;
 
-    if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
-      Dims.push_back(ArrayTy->getNumElements());
-      CurrentType = ArrayTy->getElementType();
-    } else {
-      assert(false && "Expected array type in GEP chain");
-    }
-  }
-}
+  // Obtain the variable and constant byte offsets computed by this GEP
+  const DataLayout &DL = GEP.getDataLayout();
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+  Info.ConstantOffset = {BitWidth, 0};
+  bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets,
+                                   Info.ConstantOffset);
+  (void)Success;
+  assert(Success && "Failed to collect offsets for GEP");
 
-void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
-    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
-    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
-    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
-  // Check if this GEP is already in the map to avoid circular references
-  if (GEPChainMap.count(&CurrGEP) > 0)
-    return;
+  Value *PtrOperand = GEP.getPointerOperand();
 
-  // Collect indices and dimensions from the current GEP
-  collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
-  bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
-  if (!IsMultiDimArr) {
-    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-    return;
-  }
-  bool GepUses = false;
-  for (auto *User : CurrGEP.users()) {
-    if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
-      recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
-                             ++GEPChainUseCount, Indices, Dims,
-                             AllIndicesAreConstInt);
-      GepUses = true;
+  // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
+  // it can be visited
+  if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand))
+    if (PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
+      GetElementPtrInst *OldGEPI =
+          cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
+      OldGEPI->insertBefore(GEP.getIterator());
+
+      IRBuilder<> Builder(&GEP);
+      SmallVector<Value *> Indices(GEP.idx_begin(), GEP.idx_end());
+      Value *NewGEP =
+          Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
+                            GEP.getName(), GEP.getNoWrapFlags());
+      assert(isa<GetElementPtrInst>(NewGEP) &&
+             "Expected newly-created GEP to not be a ConstantExpr");
+      GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
+
+      GEP.replaceAllUsesWith(NewGEPI);
+      GEP.eraseFromParent();
+      visitGetElementPtrInst(*OldGEPI);
+      visitGetElementPtrInst(*NewGEPI);
+      return true;
     }
-  }
-  // This case is just incase the gep chain doesn't end with a 1d array.
-  if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-  }
-}
 
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
-    GetElementPtrInst &GEP) {
-  GEPData GEPInfo = GEPChainMap.at(&GEP);
-  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
-}
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
-    GEPData &GEPInfo, GetElementPtrInst &GEP) {
-  IRBuilder<> Builder(&GEP);
-  Value *FlatIndex;
-  if (GEPInfo.AllIndicesAreConstInt)
-    FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-  else
-    FlatIndex =
-        genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-
-  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
-
-  // Don't append '.flat' to an empty string. If the SSA name isn't available
-  // it could conflict with the ParentOperand's name.
-  std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
-
-  Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
-                                     {Builder.getInt32(0), FlatIndex}, FlatName,
-                                     GEP.getNoWrapFlags());
-
-  // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
-  // Erase the old GEP in the map before to avoid invalid instructions
-  // and circular references.
-  GEPChainMap.erase(&GEP);
-
-  GEP.replaceAllUsesWith(FlatGEP);
-  GEP.eraseFromParent();
-  return true;
-}
+  // If there is a parent GEP, inherit the root array type and pointer, and
+  // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
+  // chain and we need to deterine the root array type
+  if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
+    assert(GEPChainInfoMap.contains(PtrOpGEP) &&
+           "Expected parent GEP to be visited before this GEP");
+    GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
+    Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
+    Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
+    for (auto &VariableOffset : PGEPInfo.VariableOffsets)
+      Info.VariableOffsets.insert(VariableOffset);
+    Info.ConstantOffset += PGEPInfo.ConstantOffset;
+  } else {
+    Info.RootPointerOperand = PtrOperand;
+
+    // We should try to determine the type of the root from the pointer rather
+    // than the GEP's source element type because this could be a scalar GEP
+    // into a multidimensional array-typed pointer from an Alloca or Global
+    // Variable.
+    Type *RootTy = GEP.getSourceElementType();
+    if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
+      if (!GlobalMap.contains(GlobalVar))
+        return false;
+      GlobalVariable *NewGlobal = GlobalMap[GlobalVar];
+      Info.RootPointerOperand = NewGlobal;
+      RootTy = NewGlobal->getValueType();
+    } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+      RootTy = Alloca->getAllocatedType();
+    }
+    assert(!isMultiDimensionalArray(RootTy) &&
+           "Expected root array type to be flattened");
 
-bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
-  auto It = GEPChainMap.find(&GEP);
-  if (It != GEPChainMap.end())
-    return visitGetElementPtrInstInGEPChain(GEP);
-  if (!isMultiDimensionalArray(GEP.getSourceElementType()))
-    return false;
+    // If the root type is not an array, we don't need to do any flattening
+    if (!isa<ArrayType>(RootTy))
+      return false;
 
-  ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
-  IRBuilder<> Builder(&GEP);
-  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
-  ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
+    Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
+  }
 
-  Value *PtrOperand = GEP.getPointerOperand();
+  // GEPs without users or GEPs with non-GEP users should be replaced such that
+  // the chain of GEPs they are a part of are collapsed to a single GEP into a
+  // flattened array.
+  bool ReplaceThisGEP = GEP.users().empty();
+  for (Value *User : GEP.users())
+    if (!isa<GetElementPtrInst>(User))
+      ReplaceThisGEP = true;
+
+  if (ReplaceThisGEP) {
+    // GEP.collectOffset returns the offset in bytes. So we need to divide its
+    // offsets by the size in bytes of the element type
+    unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType()
+                                ->getPrimitiveSizeInBits() /
+                            8;
+
+    // Compute the 32-bit index for this flattened GEP from the constant and
+    // variable byte offsets in the GEPInfo
+    IRBuilder<> Builder(&GEP);
+    Value *ZeroIndex = Builder.getInt32(0);
+    uint64_t ConstantOffset =
+        Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
+    assert(ConstantOffset < UINT32_MAX &&
+           "Constant byte offset for flat GEP index must fit within 32 bits");
+    Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
+    for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
+      uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue();
+      assert(Mul < UINT32_MAX &&
+             "Multiplier for flat GEP index must fit within 32 bits");
+      assert(VarIndex->getType()->isIntegerTy(32) &&
+             "Expected i32-typed GEP indices");
+      Value *ConstIntMul = Builder.getInt32(Mul);
+      Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul);
+      FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex);
+    }
 
-  unsigned GEPChainUseCount = 0;
-  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-
-  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
-  // Here recursion is used to get the length of the GEP chain.
-  // Handle zero uses here because there won't be an update via
-  // a child in the chain later.
-  if (GEPChainUseCount == 0) {
-    SmallVector<Value *> Indices;
-    SmallVector<uint64_t> Dims;
-    bool AllIndicesAreConstInt = true;
-
-    // Collect indices and dimensions from the GEP
-    collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
-    GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
-                    std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
-    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+    // Construct a new GEP for the flattened array to replace the current GEP
+    Value *NewGEP = Builder.CreateGEP(
+        Info.RootFlattenedArrayType, Info.RootPointerOperand,
+        {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
+
+    // Replace the current GEP with the new GEP. Store GEPInfo into the map
+    // for later use in case this GEP was not the end of the chain
+    GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
+    GEP.replaceAllUsesWith(NewGEP);
+    GEP.eraseFromParent();
+    return true;
   }
 
+  // This GEP is potentially dead at the end of the pass since it may not have
+  // any users anymore after GEP chains have been collapsed.
+  GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
   PotentiallyDeadInstrs.emplace_back(&GEP);
   return false;
 }
@@ -456,9 +449,9 @@ flattenGlobalArrays(Module &M,
 
 static bool flattenArrays(Module &M) {
   bool MadeChange = false;
-  DXILFlattenArraysVisitor Impl;
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
   flattenGlobalArrays(M, GlobalMap);
+  DXILFlattenArraysVisitor Impl(GlobalMap);
   for (auto &F : make_early_inc_range(M.functions())) {
     if (F.isDeclaration())
       continue;

>From e7b9d2131a327dd352664b0574dac7ddf2d7e6f9 Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Fri, 27 Jun 2025 21:23:34 +0000
Subject: [PATCH 2/6] Fix tests with incorrect GEPs

A few tests had incorrect GEP indices or types.
This commit fixes these GEPs and array types.
---
 llvm/test/CodeGen/DirectX/flatten-array.ll    |  4 +-
 .../CodeGen/DirectX/flatten-bug-117273.ll     |  8 +--
 .../DirectX/llc-vector-load-scalarize.ll      | 60 +++++++++----------
 .../test/CodeGen/DirectX/scalar-bug-117273.ll |  4 +-
 llvm/test/CodeGen/DirectX/scalarize-alloca.ll |  4 +-
 5 files changed, 38 insertions(+), 42 deletions(-)

diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index dc8c5f8421bfe..e256146bb74f4 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -159,9 +159,9 @@ define void @global_gep_load_index(i32 %row, i32 %col, i32 %timeIndex) {
 define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
 ; CHECK-LABEL: define void @global_incomplete_gep_chain(
 ; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[COL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[COL]], 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[ROW]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[ROW]], 12
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
 ; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 0, i32 [[TMP4]]
 ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
index c73e5017348d1..930805f0ddc90 100644
--- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -8,16 +8,16 @@
 define internal void @main() {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3
 ; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6
 ; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %0 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1
+  %0 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1
   %.i0 = load float, ptr %0, align 16
-  %1 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2
+  %1 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2
   %.i03 = load float, ptr %1, align 16
   ret void
 }
diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
index d5797f6b51348..78550adbe424a 100644
--- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
+++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
@@ -3,43 +3,35 @@
 
 ; Make sure we can load groupshared, static vectors and arrays of vectors
 
-@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <4 x i32>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
+@"groupshared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x <4 x i32>]] zeroinitializer, align 16
 
-; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [8 x i32] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
 ; CHECK: @staticArrayOfVecData.scalarized.1dim = internal global [12 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12], align 4
-; CHECK: @groushared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
+; CHECK: @groupshared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
 
 ; CHECK-NOT: @arrayofVecData
 ; CHECK-NOT: @arrayofVecData.scalarized
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData
 ; CHECK-NOT: @staticArrayOfVecData.scalarized
-; CHECK-NOT: @groushared2dArrayofVectors
-; CHECK-NOT: @groushared2dArrayofVectors.scalarized
+; CHECK-NOT: @groupshared2dArrayofVectors
+; CHECK-NOT: @groupshared2dArrayofVectors.scalarized
 
 define <4 x i32> @load_array_vec_test() #0 {
 ; CHECK-LABEL: define <4 x i32> @load_array_vec_test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr addrspace(3) [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) [[TMP3]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 2) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP10:%.*]] = load i32, ptr addrspace(3) [[TMP9]], align 4
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 1) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
-; CHECK-NEXT:    [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 2) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr addrspace(3) [[TMP13]], align 4
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 3) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr addrspace(3) [[TMP15]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 2), align 4
+; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3), align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 1), align 4
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 2), align 4
+; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 3), align 4
 ; CHECK-NEXT:    [[DOTI05:%.*]] = add i32 [[TMP2]], [[TMP10]]
 ; CHECK-NEXT:    [[DOTI16:%.*]] = add i32 [[TMP4]], [[TMP12]]
 ; CHECK-NEXT:    [[DOTI27:%.*]] = add i32 [[TMP6]], [[TMP14]]
@@ -77,7 +69,9 @@ define <4 x i32> @load_vec_test() #0 {
 define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
 ; CHECK-LABEL: define <4 x i32> @load_static_array_of_vec_test(
 ; CHECK-SAME: i32 [[INDEX:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 [[INDEX]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP3]]
+; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 [[TMP2]]
 ; CHECK-NEXT:    [[DOTI0:%.*]] = load i32, ptr [[DOTFLAT]], align 4
 ; CHECK-NEXT:    [[DOTFLAT_I1:%.*]] = getelementptr i32, ptr [[DOTFLAT]], i32 1
 ; CHECK-NEXT:    [[DOTI1:%.*]] = load i32, ptr [[DOTFLAT_I1]], align 4
@@ -99,14 +93,14 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
 define <4 x i32> @multid_load_test() #0 {
 ; CHECK-LABEL: define <4 x i32> @multid_load_test(
 ; CHECK-SAME: ) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 2), align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3), align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), align 4
-; CHECK-NEXT:    [[DOTI13:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 1), align 4
-; CHECK-NEXT:    [[DOTI25:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 2), align 4
-; CHECK-NEXT:    [[DOTI37:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 3), align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 1), align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 2), align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 3), align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), align 4
+; CHECK-NEXT:    [[DOTI13:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 1), align 4
+; CHECK-NEXT:    [[DOTI25:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 2), align 4
+; CHECK-NEXT:    [[DOTI37:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 3), align 4
 ; CHECK-NEXT:    [[DOTI08:%.*]] = add i32 [[TMP1]], [[TMP5]]
 ; CHECK-NEXT:    [[DOTI19:%.*]] = add i32 [[TMP2]], [[DOTI13]]
 ; CHECK-NEXT:    [[DOTI210:%.*]] = add i32 [[TMP3]], [[DOTI25]]
@@ -117,8 +111,8 @@ define <4 x i32> @multid_load_test() #0 {
 ; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3
 ; CHECK-NEXT:    ret <4 x i32> [[TMP6]]
 ;
-  %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
-  %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
+  %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groupshared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
+  %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groupshared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
   %3 = add <4 x i32> %1, %2
   ret <4 x i32> %3
 }
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index a07ce2c24f7ac..9ce2108a03831 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -8,13 +8,13 @@
 define internal void @main() #1 {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3
 ; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
 ; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
 ; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
 ; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
 ; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6
 ; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
 ; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index 32e2c3ca2c302..a8557e47b0ea6 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -33,7 +33,9 @@ define void @alloca_2d_gep_test() {
   ; FCHECK:  [[alloca_val:%.*]] = alloca [4 x i32], align 16
   ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
   ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
-  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 2
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
   ; CHECK: ret void
   %1 = alloca [2 x <2 x i32>], align 16
   %2 = tail call i32 @llvm.dx.thread.id(i32 0)

>From 1a09803d75b42e9b70c37e7e92c0ca2f3ddcf3f0 Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Sat, 28 Jun 2025 01:39:17 +0000
Subject: [PATCH 3/6] Allow flattening GEPs for Global Variables not in the
 GlobalMap

---
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 41 +++++++++----------
 llvm/test/CodeGen/DirectX/flatten-array.ll    |  6 +--
 .../CodeGen/DirectX/flatten-bug-117273.ll     |  6 +--
 .../test/CodeGen/DirectX/scalar-bug-117273.ll | 18 +++-----
 4 files changed, 30 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 913a8dcb917f4..e58cd829d96d8 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -215,18 +215,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
     return false;
 
-  // Construct GEPInfo for this GEP
-  GEPInfo Info;
-
-  // Obtain the variable and constant byte offsets computed by this GEP
-  const DataLayout &DL = GEP.getDataLayout();
-  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
-  Info.ConstantOffset = {BitWidth, 0};
-  bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets,
-                                   Info.ConstantOffset);
-  (void)Success;
-  assert(Success && "Failed to collect offsets for GEP");
-
   Value *PtrOperand = GEP.getPointerOperand();
 
   // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
@@ -243,7 +231,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
           Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
                             GEP.getName(), GEP.getNoWrapFlags());
       assert(isa<GetElementPtrInst>(NewGEP) &&
-             "Expected newly-created GEP to not be a ConstantExpr");
+             "Expected newly-created GEP to be an instruction");
       GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
 
       GEP.replaceAllUsesWith(NewGEPI);
@@ -253,6 +241,18 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       return true;
     }
 
+  // Construct GEPInfo for this GEP
+  GEPInfo Info;
+
+  // Obtain the variable and constant byte offsets computed by this GEP
+  const DataLayout &DL = GEP.getDataLayout();
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+  Info.ConstantOffset = {BitWidth, 0};
+  bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets,
+                                   Info.ConstantOffset);
+  (void)Success;
+  assert(Success && "Failed to collect offsets for GEP");
+
   // If there is a parent GEP, inherit the root array type and pointer, and
   // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
   // chain and we need to deterine the root array type
@@ -270,15 +270,13 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
 
     // We should try to determine the type of the root from the pointer rather
     // than the GEP's source element type because this could be a scalar GEP
-    // into a multidimensional array-typed pointer from an Alloca or Global
-    // Variable.
+    // into an array-typed pointer from an Alloca or Global Variable.
     Type *RootTy = GEP.getSourceElementType();
     if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
-      if (!GlobalMap.contains(GlobalVar))
-        return false;
-      GlobalVariable *NewGlobal = GlobalMap[GlobalVar];
-      Info.RootPointerOperand = NewGlobal;
-      RootTy = NewGlobal->getValueType();
+      if (GlobalMap.contains(GlobalVar))
+        GlobalVar = GlobalMap[GlobalVar];
+      Info.RootPointerOperand = GlobalVar;
+      RootTy = GlobalVar->getValueType();
     } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
       RootTy = Alloca->getAllocatedType();
     }
@@ -341,7 +339,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   }
 
   // This GEP is potentially dead at the end of the pass since it may not have
-  // any users anymore after GEP chains have been collapsed.
+  // any users anymore after GEP chains have been collapsed. We still store its
+  // GEPInfo so that GEPs further down the chain can compute their indices.
   GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
   PotentiallyDeadInstrs.emplace_back(&GEP);
   return false;
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index e256146bb74f4..dbb1d95df16f3 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -123,8 +123,7 @@ define void @gep_4d_test ()  {
 @b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16
 
 define void @global_gep_load() {
-  ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 0, i32 6
-  ; CHECK-NEXT: load i32, ptr [[GEP_PTR]], align 4
+  ; CHECK: {{.*}} = load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 0, i32 6), align 4
   ; CHECK-NEXT:    ret void
   %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 0
   %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 1
@@ -177,8 +176,7 @@ define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
 }
 
 define void @global_gep_store() {
-  ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @b.1dim, i32 0, i32 13
-  ; CHECK-NEXT: store i32 1, ptr [[GEP_PTR]], align 4
+  ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 0, i32 13), align 4
   ; CHECK-NEXT:    ret void
   %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @b, i32 0, i32 1
   %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 0
diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
index 930805f0ddc90..78971b8954150 100644
--- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -8,10 +8,8 @@
 define internal void @main() {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3
-; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6
-; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr getelementptr ([6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3), align 16
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr getelementptr ([6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6), align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index 9ce2108a03831..43bbe9249aac0 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -8,18 +8,12 @@
 define internal void @main() #1 {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3
-; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
-; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
-; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
-; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6
-; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
-; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
-; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
-; CHECK-NEXT:    [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
-; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), align 16
+; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), i32 1), align 4
+; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), i32 2), align 8
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), align 16
+; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), i32 1), align 4
+; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), i32 2), align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

>From dee77e9f5cf1bed2b13d6078d7255219cac70538 Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Sat, 28 Jun 2025 05:33:51 +0000
Subject: [PATCH 4/6] Add test demonstrating the flattening of scalar GEPs,
 including i8

---
 llvm/test/CodeGen/DirectX/flatten-array.ll | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index dbb1d95df16f3..bd83a3da24cca 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -255,5 +255,24 @@ define void @gep_4d_index_and_gep_chain_mixed() {
   ret void
 }
 
+; This test demonstrates that the collapsing of GEP chains occurs regardless of
+; the source element type given to the GEP. As long as the root pointer being
+; indexed into is an aggregate data structure, the GEP will be flattened.
+define void @gep_scalar_flatten() {
+  ; CHECK-LABEL: gep_scalar_flatten
+  ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [24 x i32]
+  ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 17
+  ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 17
+  ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 23
+  ; CHECK-NEXT: ret void
+  %a = alloca [2 x [3 x [4 x i32]]], align 4
+  %i8root = getelementptr inbounds nuw i8, [2 x [3 x [4 x i32]]]* %a, i32 68 ; 68 bytes = i32 elem 17 = %a[1][1][1]
+  %i32root = getelementptr inbounds nuw i32, [2 x [3 x [4 x i32]]]* %a, i32 17 ; i32 elem 17 = %a[1][1][1]
+  %c0 = getelementptr inbounds nuw [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* %a, i32 0, i32 1 ; %a[1] (elem 12)
+  %c1 = getelementptr inbounds nuw i32, [3 x [4 x i32]]* %c0, i32 8 ; +8 elems: %a[1][2] (elem 20)
+  %c2 = getelementptr inbounds nuw i8, [4 x i32]* %c1, i32 12 ; +12 bytes: %a[1][2][3] (elem 23)
+  ret void
+}
+
 ; Make sure we don't try to walk the body of a function declaration.
 declare void @opaque_function()

>From 1ccaa986e4e9e024d91ef6a67509af657c7c9ebc Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Mon, 30 Jun 2025 21:25:57 +0000
Subject: [PATCH 5/6] Clear GEPChainInfoMap after visiting a function

The same GEPOperator* value can be encountered again in a later function
visit, so the map must be cleared between visits; otherwise stale entries
could cause GEPs to be incorrectly generated or skipped over.
---
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index e58cd829d96d8..0f1c7673da2cf 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -92,6 +92,7 @@ class DXILFlattenArraysVisitor
 } // namespace
 
 bool DXILFlattenArraysVisitor::finish() {
+  GEPChainInfoMap.clear(); // Avoid stale entries in the next function.
   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
   return true;
 }

>From 32806096da0117831839567b18def3b8cd1a474e Mon Sep 17 00:00:00 2001
From: Icohedron <cheung.deric at gmail.com>
Date: Mon, 30 Jun 2025 21:49:14 +0000
Subject: [PATCH 6/6] Fix CHECK line to expect a ConstantExpr GEP instead of an instruction

---
 llvm/test/CodeGen/DirectX/flatten-array.ll | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index bd83a3da24cca..97e4d7a709260 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -202,8 +202,7 @@ define void @two_index_gep() {
 
 define void @two_index_gep_const() {
   ; CHECK-LABEL: define void @two_index_gep_const(
-  ; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3
-  ; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4
+  ; CHECK-NEXT: load float, ptr addrspace(3) getelementptr inbounds nuw ([4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3), align 4
   ; CHECK-NEXT: ret void
   %1 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 1, i32 1
   %3 = load float, ptr addrspace(3) %1, align 4



More information about the llvm-commits mailing list