[llvm] [DirectX] Flatten arrays (PR #114332)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 11:15:17 PST 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/114332

>From 76914836eede9661100cbf220335e424719f92f4 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 10 Oct 2024 04:59:08 -0400
Subject: [PATCH 1/3] [DirectX] A pass to flatten arrays

---
 llvm/lib/Target/DirectX/CMakeLists.txt        |   1 +
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 460 ++++++++++++++++++
 llvm/lib/Target/DirectX/DXILFlattenArrays.h   |  25 +
 llvm/lib/Target/DirectX/DirectX.h             |   6 +
 .../Target/DirectX/DirectXPassRegistry.def    |   1 +
 .../Target/DirectX/DirectXTargetMachine.cpp   |   3 +
 llvm/test/CodeGen/DirectX/flatten-array.ll    | 192 ++++++++
 llvm/test/CodeGen/DirectX/llc-pipeline.ll     |   1 +
 llvm/test/CodeGen/DirectX/scalar-data.ll      |  11 +-
 llvm/test/CodeGen/DirectX/scalar-load.ll      |  23 +-
 llvm/test/CodeGen/DirectX/scalar-store.ll     |  11 +-
 11 files changed, 719 insertions(+), 15 deletions(-)
 create mode 100644 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
 create mode 100644 llvm/lib/Target/DirectX/DXILFlattenArrays.h
 create mode 100644 llvm/test/CodeGen/DirectX/flatten-array.ll

diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 5d1dc50fdb0dde..a726071e0dcecd 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(DirectXCodeGen
   DXContainerGlobals.cpp
   DXILDataScalarization.cpp
   DXILFinalizeLinkage.cpp
+  DXILFlattenArrays.cpp
   DXILIntrinsicExpansion.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
new file mode 100644
index 00000000000000..e4660909c438ee
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -0,0 +1,460 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+/// Legacy-pass-manager wrapper that runs the array flattening over a module.
+class DXILFlattenArraysLegacy : public ModulePass {
+public:
+  static char ID; // Pass identification.
+
+  DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+};
+
+/// Everything needed to rebuild the terminal GEP of a chain of GEPs over a
+/// multi-dimensional array as a single GEP into the flattened array.
+struct GEPData {
+  // The flattened (one-dimensional) array type.
+  ArrayType *ParentArrayType;
+  // Pointer operand of the chain's root GEP. NOTE(review): "Parend" is
+  // presumably a typo for "Parent".
+  Value *ParendOperand;
+  // Trailing index of each GEP in the chain, outermost first.
+  SmallVector<Value *> Indices;
+  // Extent of the array indexed by the corresponding entry in Indices.
+  SmallVector<uint64_t> Dims;
+  // True when every entry in Indices is a ConstantInt.
+  bool AllIndicesAreConstInt;
+};
+
+/// Instruction visitor that rewrites allocas and GEPs over multi-dimensional
+/// arrays so that they use an equivalent one-dimensional array type.
+class DXILFlattenArraysVisitor
+    : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+  DXILFlattenArraysVisitor() {}
+  /// Visit the instructions of \p F, walking its basic blocks in reverse
+  /// post-order.
+  bool visit(Function &F);
+  // InstVisitor methods.  They return true if the instruction was flattened
+  // (the visitor has then already replaced and erased it), false if nothing
+  // changed.
+  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+  bool visitAllocaInst(AllocaInst &AI);
+  // The remaining instruction kinds are never rewritten by this visitor.
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitSelectInst(SelectInst &SI) { return false; }
+  bool visitICmpInst(ICmpInst &ICI) { return false; }
+  bool visitFCmpInst(FCmpInst &FCI) { return false; }
+  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+  bool visitCastInst(CastInst &CI) { return false; }
+  bool visitBitCastInst(BitCastInst &BCI) { return false; }
+  bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+  bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+  bool visitPHINode(PHINode &PHI) { return false; }
+  bool visitLoadInst(LoadInst &LI) { return false; }
+  bool visitStoreInst(StoreInst &SI) { return false; }
+  bool visitCallInst(CallInst &ICI) { return false; }
+  bool visitFreezeInst(FreezeInst &FI) { return false; }
+  /// True if \p T is an array type whose element type is also an array.
+  static bool isMultiDimensionalArray(Type *T);
+  /// Product of the extents of all nested array levels of \p ArrayTy.
+  static unsigned getTotalElements(Type *ArrayTy);
+  /// The innermost non-array element type of \p ArrayTy.
+  static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+  // Root GEPs of collected chains; erased by finish() once trivially dead.
+  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+  // Terminal GEP of each collected chain -> data needed to rebuild it as a
+  // single GEP into the flattened array.
+  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+  bool finish();
+  // Folds all-constant indices into a single i32 flat index.
+  ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  // Emits mul/add instructions computing the flat index.
+  Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  // Follows nested GEP users of CurrGEP, recording each chain end in
+  // GEPChainMap together with the indices/dimensions gathered on the way.
+  void
+  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+                         ArrayType *FlattenedArrayType, Value *PtrOperand,
+                         unsigned &GEPChainUseCount,
+                         SmallVector<Value *> Indices = SmallVector<Value *>(),
+                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+                         bool AllIndicesAreConstInt = true);
+  // NOTE(review): not called anywhere in this file.
+  ConstantInt *computeFlatIndex(GetElementPtrInst &GEP);
+  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+                                            GetElementPtrInst &GEP);
+};
+
+/// Erase the recorded chain-root GEPs that are now trivially dead, together
+/// with any operands that become dead as a result. Always reports success.
+bool DXILFlattenArraysVisitor::finish() {
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(
+      PotentiallyDeadInstrs);
+  const bool Finished = true;
+  return Finished;
+}
+
+/// Return true when \p T is an array whose elements are themselves arrays.
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+  auto *Outer = dyn_cast<ArrayType>(T);
+  return Outer && Outer->getElementType()->isArrayTy();
+}
+
+/// Return the product of the extents of every nested array level of
+/// \p ArrayTy (1 when \p ArrayTy is not an array at all).
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+  unsigned Count = 1;
+  for (auto *Level = dyn_cast<ArrayType>(ArrayTy); Level;
+       Level = dyn_cast<ArrayType>(Level->getElementType()))
+    Count *= Level->getNumElements();
+  return Count;
+}
+
+/// Strip every array level off \p ArrayTy and return the innermost element
+/// type (\p ArrayTy itself when it is not an array).
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+  Type *Elem = ArrayTy;
+  while (Elem->isArrayTy())
+    Elem = Elem->getArrayElementType();
+  return Elem;
+}
+
+/// Fold the all-constant \p Indices (outermost first) over arrays with
+/// extents \p Dims into a single row-major flat index, returned as an i32
+/// constant. Every entry of \p Indices must be a ConstantInt.
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  assert(Indices.size() == Dims.size() &&
+         "Indices and dimensions should be the same size");
+  // The last index varies fastest: walk from the innermost index outwards,
+  // growing the stride by each dimension's extent.
+  uint64_t FlatIndex = 0;
+  uint64_t Multiplier = 1;
+  for (size_t I = Indices.size(); I-- > 0;) {
+    // cast<> asserts the ConstantInt precondition itself.
+    auto *CIndex = cast<ConstantInt>(Indices[I]);
+    FlatIndex += CIndex->getZExtValue() * Multiplier;
+    Multiplier *= Dims[I];
+  }
+  return Builder.getInt32(FlatIndex);
+}
+
+/// Emit IR computing the row-major flat index of \p Indices (outermost first)
+/// over arrays with extents \p Dims, inserting at \p Builder's insertion
+/// point. A single index is returned unchanged without emitting anything.
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  if (Indices.size() == 1)
+    return Indices.front();
+
+  // Sum += Indices[Pos] * Stride, walking from the innermost index outwards;
+  // Stride is the product of the extents of all inner dimensions.
+  Value *Sum = Builder.getInt32(0);
+  unsigned Stride = 1;
+  for (size_t Pos = Indices.size(); Pos-- > 0;) {
+    Value *Scaled = Builder.CreateMul(Indices[Pos], Builder.getInt32(Stride));
+    Sum = Builder.CreateAdd(Sum, Scaled);
+    Stride *= static_cast<unsigned>(Dims[Pos]);
+  }
+  return Sum;
+}
+
+/// Replace an alloca of a multi-dimensional array with an alloca of the
+/// equivalent one-dimensional array (same total element count and base
+/// element type), then erase the original. Returns true if \p AI was replaced.
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+  Type *AllocTy = AI.getAllocatedType();
+  if (!isMultiDimensionalArray(AllocTy))
+    return false;
+
+  auto *NestedTy = cast<ArrayType>(AllocTy);
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(NestedTy), getTotalElements(NestedTy));
+
+  IRBuilder<> Builder(&AI);
+  // Keep the original address space and array-size operand so that
+  // `alloca [..], i32 %n` and non-default address spaces are preserved.
+  AllocaInst *FlatAlloca =
+      Builder.CreateAlloca(FlattenedArrayType, AI.getAddressSpace(),
+                           AI.getArraySize(), AI.getName() + ".flat");
+  FlatAlloca->setAlignment(AI.getAlign());
+  AI.replaceAllUsesWith(FlatAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
+/// Compute the row-major element offset addressed by an all-constant \p GEP
+/// over a (possibly nested) array type, in units of the innermost element
+/// type. Returns nullptr if any index is not a ConstantInt. A GEP with a
+/// single index returns that index as-is (nullptr if it is not constant).
+///
+/// NOTE(review): not called anywhere in this file.
+ConstantInt *
+DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
+  unsigned IndexAmount = GEP.getNumIndices();
+  assert(IndexAmount >= 1 && "Need At least one Index");
+  if (IndexAmount == 1)
+    return dyn_cast<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1));
+
+  // Extents of each nested array level, outermost first.
+  SmallVector<uint64_t> Dimensions;
+  Type *BaseType = GEP.getSourceElementType();
+  while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
+    Dimensions.push_back(ArrType->getNumElements());
+    BaseType = ArrType->getElementType();
+  }
+
+  // Bail out before touching any index if one of them is not a constant.
+  SmallVector<ConstantInt *> ConstIndices;
+  for (const Use &Index : GEP.indices()) {
+    auto *CI = dyn_cast<ConstantInt>(Index);
+    if (!CI)
+      return nullptr;
+    ConstIndices.push_back(CI);
+  }
+
+  // Index 0 steps over the whole source type; index K >= 1 selects within
+  // array level K - 1. The innermost supplied index therefore has a stride
+  // equal to the product of the extents of the levels it does not reach.
+  uint64_t Stride = 1;
+  for (size_t Level = ConstIndices.size() - 1; Level < Dimensions.size();
+       ++Level)
+    Stride *= Dimensions[Level];
+
+  // Walk from the innermost index outwards, growing the stride by the extent
+  // of each level that has been consumed.
+  int64_t FlatIndex = 0;
+  for (size_t Pos = ConstIndices.size(); Pos-- > 0;) {
+    FlatIndex +=
+        ConstIndices[Pos]->getSExtValue() * static_cast<int64_t>(Stride);
+    if (Pos >= 1 && Pos - 1 < Dimensions.size())
+      Stride *= Dimensions[Pos - 1];
+  }
+  // The result uses the integer type of the last index.
+  return ConstantInt::get(ConstIndices.back()->getIntegerType(), FlatIndex,
+                          /*IsSigned=*/true);
+}
+
+/// Starting at \p CurrGEP, follow nested GEP users down a chain of GEPs over
+/// a multi-dimensional array. At every level the GEP's trailing index and the
+/// extent of the array it indexes are appended. A chain ends at a GEP over a
+/// one-dimensional array, or at a multi-dimensional GEP with no GEP users
+/// once at least one nested GEP has been seen; the data gathered so far is
+/// then recorded in GEPChainMap for that terminal GEP. \p GEPChainUseCount
+/// is incremented once per nested GEP visited.
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
+    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
+  Value *TrailingIdx = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+  AllIndicesAreConstInt =
+      AllIndicesAreConstInt && isa<ConstantInt>(TrailingIdx);
+  Indices.push_back(TrailingIdx);
+
+  auto *SrcArrTy = cast<ArrayType>(CurrGEP.getSourceElementType());
+  Dims.push_back(SrcArrTy->getNumElements());
+
+  if (!isMultiDimensionalArray(SrcArrTy)) {
+    // Innermost level reached: this GEP terminates the chain.
+    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
+    GEPChainMap.insert({&CurrGEP,
+                        {FlattenedArrayType, PtrOperand, std::move(Indices),
+                         std::move(Dims), AllIndicesAreConstInt}});
+    return;
+  }
+
+  bool HasNestedGEP = false;
+  for (User *U : CurrGEP.users()) {
+    auto *NestedGEP = dyn_cast<GetElementPtrInst>(U);
+    if (!NestedGEP)
+      continue;
+    ++GEPChainUseCount;
+    recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+                           GEPChainUseCount, Indices, Dims,
+                           AllIndicesAreConstInt);
+    HasNestedGEP = true;
+  }
+
+  // The chain may also stop early, at a multi-dimensional GEP that has no
+  // further GEP users.
+  if (GEPChainUseCount > 0 && !HasNestedGEP)
+    GEPChainMap.insert({&CurrGEP,
+                        {FlattenedArrayType, PtrOperand, std::move(Indices),
+                         std::move(Dims), AllIndicesAreConstInt}});
+}
+
+/// Rewrite a GEP that was recorded as the end of a collected chain, using the
+/// data stored for it in GEPChainMap.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+    GetElementPtrInst &GEP) {
+  auto It = GEPChainMap.find(&GEP);
+  assert(It != GEPChainMap.end() && "GEP is not part of a collected chain");
+  // Take the recorded data (no copy of the index/dimension vectors) and drop
+  // the entry now: GEP is erased below, and a key pointing at a deleted
+  // instruction must not linger in the map.
+  GEPData GEPInfo = std::move(It->second);
+  GEPChainMap.erase(It);
+  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+/// Replace \p GEP with a single GEP into the flattened array described by
+/// \p GEPInfo, then RAUW and erase \p GEP. The flat index is folded to an i32
+/// constant when every recorded index is a ConstantInt; otherwise mul/add
+/// instructions computing it are inserted before \p GEP. Only the inbounds
+/// flag of \p GEP is carried over. Always returns true.
+///
+/// NOTE(review): the replacement GEP passes FlatIndex as its only index with
+/// source element type [N x T]. In GEP semantics the first index steps over
+/// whole source elements, so the computed address is
+/// base + FlatIndex * sizeof([N x T]) rather than base + FlatIndex * sizeof(T).
+/// A leading zero index ({0, FlatIndex}) looks like what is intended; the
+/// expected IR in the accompanying tests would need to change with it.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+    GEPData &GEPInfo, GetElementPtrInst &GEP) {
+  IRBuilder<> Builder(&GEP);
+  Value *FlatIndex;
+  if (GEPInfo.AllIndicesAreConstInt)
+    FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+  else
+    FlatIndex =
+        instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
+  Value *FlatGEP =
+      Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
+                        GEP.getName() + ".flat", GEP.isInBounds());
+
+  GEP.replaceAllUsesWith(FlatGEP);
+  GEP.eraseFromParent();
+  return true;
+}
+
+/// Flatten a GEP over a multi-dimensional array.
+///
+/// GEPs already recorded as the end of a chain are rewritten from that
+/// record. Otherwise this GEP is treated as the root of a new chain:
+/// - if nested GEP users were found, they are rewritten when visited later
+///   and this GEP is queued for deletion (returns false);
+/// - if none were found, the GEP is rewritten on its own from its trailing
+///   index.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  if (GEPChainMap.count(&GEP))
+    return visitGetElementPtrInstInGEPChain(GEP);
+
+  Type *SrcElemTy = GEP.getSourceElementType();
+  if (!isMultiDimensionalArray(SrcElemTy))
+    return false;
+
+  auto *ArrType = cast<ArrayType>(SrcElemTy);
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), getTotalElements(ArrType));
+  Value *PtrOperand = GEP.getPointerOperand();
+
+  // The count gathered here is the number of nested GEPs in the chain; it is
+  // not the same thing as GEP.hasNUses(0).
+  unsigned GEPChainUseCount = 0;
+  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+  if (GEPChainUseCount > 0) {
+    PotentiallyDeadInstrs.emplace_back(&GEP);
+    return false;
+  }
+
+  // No nested GEPs: nothing later in a chain will rewrite this GEP, so do it
+  // now from its own trailing index.
+  Value *TrailingIdx = GEP.getOperand(GEP.getNumOperands() - 1);
+  GEPData GEPInfo{FlattenedArrayType,
+                  PtrOperand,
+                  {TrailingIdx},
+                  {ArrType->getNumElements()},
+                  isa<ConstantInt>(TrailingIdx)};
+  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+
+/// Visit every instruction of \p F, walking basic blocks in reverse
+/// post-order so that definitions are generally visited before their uses,
+/// then delete chain-root GEPs that became dead. Returns true if any
+/// instruction was flattened.
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+    for (Instruction &I : make_early_inc_range(*BB)) {
+      // A visitor returning true has already replaced and erased I, so I
+      // must not be touched again after this call.
+      MadeChange |= InstVisitor::visit(I);
+    }
+  }
+  finish();
+  return MadeChange;
+}
+
+/// Append the scalar (non-array) leaves of \p Init to \p Elements in row-major
+/// order. Nested sub-arrays may be any array constant that can report its
+/// elements: ConstantArray, ConstantDataArray, zeroinitializer or undef/poison.
+static void collectElements(Constant *Init,
+                            SmallVectorImpl<Constant *> &Elements) {
+  auto *ArrTy = dyn_cast<ArrayType>(Init->getType());
+  // Base case: a non-array constant is a leaf.
+  if (!ArrTy) {
+    Elements.push_back(Init);
+    return;
+  }
+
+  // getAggregateElement handles every array constant kind uniformly, so a
+  // nested zeroinitializer/undef sub-array is expanded instead of dropped.
+  for (unsigned I = 0, E = ArrTy->getNumElements(); I != E; ++I) {
+    Constant *Elt = Init->getAggregateElement(I);
+    assert(Elt && "Unsupported constant kind in array initializer!");
+    collectElements(Elt, Elements);
+  }
+}
+
+/// Build the initializer of a flattened global from the original initializer
+/// \p Init of type \p OrigType:
+/// - zeroinitializer maps to zeroinitializer of \p FlattenedType;
+/// - any UndefValue (which includes poison) maps to undef of \p FlattenedType;
+/// - non-array initializers are returned unchanged;
+/// - array initializers are flattened element-by-element in row-major order.
+/// \p Ctx is currently unused.
+static Constant *transformInitializer(Constant *Init, Type *OrigType,
+                                      ArrayType *FlattenedType,
+                                      LLVMContext &Ctx) {
+  if (isa<ConstantAggregateZero>(Init))
+    return ConstantAggregateZero::get(FlattenedType);
+  if (isa<UndefValue>(Init))
+    return UndefValue::get(FlattenedType);
+  if (!OrigType->isArrayTy())
+    return Init;
+
+  SmallVector<Constant *> FlatElts;
+  collectElements(Init, FlatElts);
+  assert(FlattenedType->getNumElements() == FlatElts.size() &&
+         "The number of collected elements should match the FlattenedType");
+  return ConstantArray::get(FlattenedType, FlatElts);
+}
+
+/// For every global whose value type is a multi-dimensional array, create a
+/// one-dimensional replacement named "<name>.1dim", inserted before the
+/// original, with its initializer flattened in row-major order. Each
+/// original -> replacement pair is recorded in \p GlobalMap; the originals are
+/// left in place for the caller to replace and erase.
+static void
+flattenGlobalArrays(Module &M,
+                    DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+  LLVMContext &Ctx = M.getContext();
+  for (GlobalVariable &G : M.globals()) {
+    Type *OrigType = G.getValueType();
+    if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
+      continue;
+
+    ArrayType *ArrType = cast<ArrayType>(OrigType);
+    unsigned TotalElements =
+        DXILFlattenArraysVisitor::getTotalElements(ArrType);
+    ArrayType *FlattenedArrayType = ArrayType::get(
+        DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+
+    // Create a new global variable with the updated type.
+    // Note: the initializer is set below via transformInitializer.
+    GlobalVariable *NewGlobal = new GlobalVariable(
+        M, FlattenedArrayType, G.isConstant(), G.getLinkage(),
+        /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
+        G.getThreadLocalMode(), G.getAddressSpace(),
+        G.isExternallyInitialized());
+
+    // Copy the remaining global attributes (unnamed_addr, alignment,
+    // visibility, section, DLL storage class, dso_local, ...) wholesale, so
+    // that nothing beyond a hand-picked subset is lost.
+    NewGlobal->copyAttributesFrom(&G);
+
+    if (G.hasInitializer()) {
+      Constant *Init = G.getInitializer();
+      Constant *NewInit =
+          transformInitializer(Init, OrigType, FlattenedArrayType, Ctx);
+      NewGlobal->setInitializer(NewInit);
+    }
+    GlobalMap[&G] = NewGlobal;
+  }
+}
+
+/// Flatten multi-dimensional arrays across \p M: create flattened globals,
+/// rewrite allocas/GEPs in every function body, then replace and erase the
+/// original globals. Returns true if the module changed.
+static bool flattenArrays(Module &M) {
+  bool MadeChange = false;
+  DXILFlattenArraysVisitor Impl;
+  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  flattenGlobalArrays(M, GlobalMap);
+  for (Function &F : make_early_inc_range(M.functions())) {
+    // Declarations (intrinsics included) have no body; a reverse post-order
+    // traversal needs an entry block, so they must be skipped.
+    if (F.isDeclaration())
+      continue;
+    MadeChange |= Impl.visit(F);
+  }
+  for (auto &[Old, New] : GlobalMap) {
+    Old->replaceAllUsesWith(New);
+    Old->eraseFromParent();
+    MadeChange = true;
+  }
+  return MadeChange;
+}
+
+/// New-pass-manager entry point. When the module changed, only
+/// DXILResourceAnalysis is declared preserved; everything else is invalidated.
+PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
+  if (!flattenArrays(M))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses Kept;
+  Kept.preserve<DXILResourceAnalysis>();
+  return Kept;
+}
+
+/// Legacy-pass-manager entry point; shares all work with the new-PM run().
+bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
+  const bool Changed = flattenArrays(M);
+  return Changed;
+}
+
+/// Declare the DXIL resource analysis as preserved, mirroring the new-PM run().
+void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addPreserved<DXILResourceWrapperPass>();
+}
+
+// Pass identification; the address of ID uniquely identifies the pass.
+char DXILFlattenArraysLegacy::ID = 0;
+
+// Register the legacy pass under DEBUG_TYPE ("dxil-flatten-arrays"). No
+// INITIALIZE_PASS_DEPENDENCY entries are listed, and the trailing flags mark
+// it as neither CFG-only nor an analysis.
+INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
+                      "DXIL Array Flattener", false, false)
+INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
+                    false, false)
+
+/// Factory used by the DirectX codegen pipeline; the pass manager takes
+/// ownership of the returned pass.
+ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
+  auto *Pass = new DXILFlattenArraysLegacy();
+  return Pass;
+}
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
new file mode 100644
index 00000000000000..409f8d198782c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -0,0 +1,25 @@
+//===- DXILFlattenArrays.h - Perform flattening of DXIL Arrays -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// New-pass-manager pass that rewrites multi-dimensional arrays (globals,
+/// allocas and the GEPs addressing them) into one-dimensional arrays.
+class DXILFlattenArrays : public PassInfoMixin<DXILFlattenArrays> {
+public:
+  PreservedAnalyses run(Module &Mod, ModuleAnalysisManager &MAM);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3221779be2f311..3454f16ecd5955 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -40,6 +40,12 @@ void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
 /// Pass to scalarize llvm global data into a DXIL legal form
 ModulePass *createDXILDataScalarizationLegacyPass();
 
+/// Initializer for the DXIL array flattening pass (legacy pass manager).
+void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
+
+/// Pass to flatten multi-dimensional arrays into the one-dimensional form
+/// that is legal in DXIL.
+ModulePass *createDXILFlattenArraysLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index ae729a1082b867..a0f864ed39375f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -24,6 +24,7 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
 #define MODULE_PASS(NAME, CREATE_PASS)
 #endif
 MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
+MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
 MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion())
 MODULE_PASS("dxil-op-lower", DXILOpLowering())
 MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs()))
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 18251ea3bd01d3..59dbf053d6c222 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -13,6 +13,7 @@
 
 #include "DirectXTargetMachine.h"
 #include "DXILDataScalarization.h"
+#include "DXILFlattenArrays.h"
 #include "DXILIntrinsicExpansion.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
@@ -48,6 +49,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   auto *PR = PassRegistry::getPassRegistry();
   initializeDXILIntrinsicExpansionLegacyPass(*PR);
   initializeDXILDataScalarizationLegacyPass(*PR);
+  initializeDXILFlattenArraysLegacyPass(*PR);
   initializeScalarizerLegacyPassPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
@@ -91,6 +93,7 @@ class DirectXPassConfig : public TargetPassConfig {
     addPass(createDXILDataScalarizationLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
+    addPass(createDXILFlattenArraysLegacyPass());
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILFinalizeLinkageLegacyPass());
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
new file mode 100644
index 00000000000000..f3815b7f270717
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -0,0 +1,192 @@
+; RUN: opt -S -dxil-flatten-arrays  %s | FileCheck %s
+
+; Multi-dimensional allocas become one flat alloca of the total element count.
+; CHECK-LABEL: alloca_2d_test
+define void @alloca_2d_test ()  {
+    ; CHECK: alloca [9 x i32], align 4
+    ; CHECK-NOT: alloca [3 x [3 x i32]], align 4
+    %1 = alloca [3 x [3 x i32]], align 4
+    ret void
+}
+
+; CHECK-LABEL: alloca_3d_test
+define void @alloca_3d_test ()  {
+    ; CHECK: alloca [8 x i32], align 4
+    ; CHECK-NOT: alloca [2 x [2 x [2 x i32]]], align 4
+    %1 = alloca [2 x[2 x [2 x i32]]], align 4
+    ret void
+}
+
+; CHECK-LABEL: alloca_4d_test
+define void @alloca_4d_test ()  {
+    ; CHECK: alloca [16 x i32], align 4
+    ; CHECK-NOT: alloca [2 x [2 x [2 x [2 x i32]]]], align 4
+    %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+    ret void
+}
+
+; Each constant-index GEP chain collapses into one GEP on the flat alloca.
+; CHECK-LABEL: gep_2d_test
+define void @gep_2d_test ()  {
+    ; CHECK: [[a:%.*]] = alloca [9 x i32], align 4
+    ; CHECK-COUNT-9: getelementptr inbounds [9 x i32], ptr [[a]], i32 {{[0-8]}}
+    ; CHECK-NOT: getelementptr inbounds [3 x [3 x i32]]
+    ; CHECK-NOT: getelementptr inbounds [3 x i32]
+    %1 = alloca [3 x [3 x i32]], align 4
+    %g2d0 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 0
+    %g1d_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 0
+    %g1d_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 1
+    %g1d_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 2
+    %g2d1 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 1
+    %g1d1_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 0
+    %g1d1_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 1
+    %g1d1_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 2
+    %g2d2 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 2
+    %g1d2_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 0
+    %g1d2_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 1
+    %g1d2_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 2
+
+    ret void
+}
+
+; CHECK-LABEL: gep_3d_test
+define void @gep_3d_test ()  {
+    ; CHECK: [[a:%.*]] = alloca [8 x i32], align 4
+    ; CHECK-COUNT-8: getelementptr inbounds [8 x i32], ptr [[a]], i32 {{[0-7]}}
+    ; CHECK-NOT: getelementptr inbounds [2 x [2 x [2 x i32]]]
+    ; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]]
+    ; CHECK-NOT: getelementptr inbounds [2 x i32]
+    %1 = alloca [2 x[2 x [2 x i32]]], align 4
+    %g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 0
+    %g2d0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
+    %g1d_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 0
+    %g1d_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 1
+    %g2d1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 1
+    %g1d1_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 0
+    %g1d1_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 1
+    %g3d1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 1
+    %g2d2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 0
+    %g1d2_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d2, i32 0, i32 0
+    %g1d2_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d2, i32 0, i32 1
+    %g2d3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 1
+    %g1d3_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d3, i32 0, i32 0
+    %g1d3_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d3, i32 0, i32 1
+    ret void
+}
+
+; CHECK-LABEL: gep_4d_test
+define void @gep_4d_test ()  {
+    ; CHECK: [[a:%.*]] = alloca [16 x i32], align 4
+    ; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[a]], i32 {{[0-9]|1[0-5]}}
+    ; CHECK-NOT: getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]]
+    ; CHECK-NOT: getelementptr inbounds [2 x [2 x [2 x i32]]]
+    ; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]]
+    ; CHECK-NOT: getelementptr inbounds [2 x i32]
+    %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+    %g4d0 = getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], [2x[2 x[2 x [2 x i32]]]]* %1, i32 0, i32 0
+    %g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d0, i32 0, i32 0
+    %g2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
+    %g1d_0 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_0, i32 0, i32 0
+    %g1d_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_0, i32 0, i32 1
+    %g2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 1
+    %g1d_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_1, i32 0, i32 0
+    %g1d_3 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_1, i32 0, i32 1
+    %g3d1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d0, i32 0, i32 1
+    %g2d0_2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 0
+    %g1d_4 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_2, i32 0, i32 0
+    %g1d_5 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_2, i32 0, i32 1
+    %g2d1_2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 1
+    %g1d_6 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_2, i32 0, i32 0
+    %g1d_7 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_2, i32 0, i32 1
+    %g4d1 = getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], [2x[2 x[2 x [2 x i32]]]]* %1, i32 0, i32 1
+    %g3d0_1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d1, i32 0, i32 0
+    %g2d0_3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0_1, i32 0, i32 0
+    %g1d_8 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_3, i32 0, i32 0
+    %g1d_9 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_3, i32 0, i32 1
+    %g2d0_4 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0_1, i32 0, i32 1
+    %g1d_10 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_4, i32 0, i32 0
+    %g1d_11 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_4, i32 0, i32 1
+    %g3d1_1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d1, i32 0, i32 1
+    %g2d0_5 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1_1, i32 0, i32 0
+    %g1d_12 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_5, i32 0, i32 0
+    %g1d_13 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_5, i32 0, i32 1
+    %g2d1_3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1_1, i32 0, i32 1
+    %g1d_14 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_3, i32 0, i32 0
+    %g1d_15 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_3, i32 0, i32 1
+    ret void
+}
+
+
+; @a has a fully specified nested initializer; @b is zero-initialized.
+@a = internal global [2 x [3 x [4 x i32]]] [[3 x [4 x i32]] [[4 x i32] [i32 0, i32 1, i32 2, i32 3],
+                                                             [4 x i32] [i32 4, i32 5, i32 6, i32 7],
+                                                             [4 x i32] [i32 8, i32 9, i32 10, i32 11]],
+                                            [3 x [4 x i32]] [[4 x i32] [i32 12, i32 13, i32 14, i32 15],
+                                                             [4 x i32] [i32 16, i32 17, i32 18, i32 19],
+                                                             [4 x i32] [i32 20, i32 21, i32 22, i32 23]]], align 4
+
+@b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16
+
+; CHECK-LABEL: define void @global_gep_load(
+define void @global_gep_load() {
+  ; CHECK: load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 6), align 4
+  ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+  ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+  ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+  %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 0
+  %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 1
+  %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 2
+  %4 = load i32, i32* %3, align 4
+  ret void
+}
+
+; Non-constant indices along a full GEP chain are combined with mul/add into a
+; single flat index before one GEP on the flattened global.
+define void @global_gep_load_index(i32 %row, i32 %col, i32 %timeIndex) {
+; CHECK-LABEL: define void @global_gep_load_index(
+; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]], i32 [[TIMEINDEX:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[TIMEINDEX]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[COL]], 4
+; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i32 [[ROW]], 12
+; CHECK-NEXT:    [[TMP6:%.*]] = add i32 [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 [[TMP6]]
+; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+; CHECK-NEXT:    [[TMP7:%.*]] = load i32, ptr [[DOTFLAT]], align 4
+; CHECK-NEXT:    ret void
+;
+  %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 %row
+  %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 %col
+  %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 %timeIndex
+  %4 = load i32, i32* %3, align 4
+  ret void
+}
+
+; A GEP chain that stops before the innermost dimension is still flattened;
+; only the indices present in the chain contribute to the flat index.
+define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
+; CHECK-LABEL: define void @global_incomplete_gep_chain(
+; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[COL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[ROW]], 3
+; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 [[TMP4]]
+; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+; CHECK-NEXT:    [[TMP5:%.*]] = load i32, ptr [[DOTFLAT]], align 4
+; CHECK-NEXT:    ret void
+;
+  %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 %row
+  %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 %col
+  %4 = load i32, i32* %2, align 4
+  ret void
+}
+
+; CHECK-LABEL: define void @global_gep_store(
+define void @global_gep_store() {
+  ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 13), align 4
+  ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+  ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+  ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+  %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @b, i32 0, i32 1
+  %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 0
+  %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 1
+  store i32 1, i32* %3, align 4
+  ret void
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 224037cfe7fbe3..f0950df08eff5b 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -9,6 +9,7 @@
 ; CHECK-NEXT: ModulePass Manager
 ; CHECK-NEXT:   DXIL Intrinsic Expansion
 ; CHECK-NEXT:   DXIL Data Scalarization
+; CHECK-NEXT:   DXIL Array Flattener
 ; CHECK-NEXT:   FunctionPass Manager
 ; CHECK-NEXT:     Dominator Tree Construction
 ; CHECK-NEXT:     Scalarize vector operations
diff --git a/llvm/test/CodeGen/DirectX/scalar-data.ll b/llvm/test/CodeGen/DirectX/scalar-data.ll
index c436f1eae4425d..11b76205aa54c6 100644
--- a/llvm/test/CodeGen/DirectX/scalar-data.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-data.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 
 ; Make sure we don't touch arrays without vectors and that can recurse multiple-dimension arrays of vectors
@@ -8,5 +9,9 @@
 
 ; CHECK @staticArray
 ; CHECK-NOT: @staticArray.scalarized
-; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
-; CHECK-NOT: @groushared3dArrayofVectors
+; CHECK-NOT: @staticArray.scalarized.1dim
+; CHECK-NOT: @staticArray.1dim
+; DATACHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
+; CHECK: @groushared3dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [108 x i32] zeroinitializer, align 16
+; DATACHECK-NOT: @groushared3dArrayofVectors
+; CHECK-NOT: @groushared3dArrayofVectors.scalarized
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index b911a8f7855bb8..7d01c0cfa7fa69 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s  -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 
 ; Make sure we can load groupshared, static vectors and arrays of vectors
@@ -8,20 +9,25 @@
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
 @"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
 
-; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; DATACHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
-; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
-; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
+; DATACHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+; CHECK: @staticArrayOfVecData.scalarized.1dim = internal global [12 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12], align 4
+; DATACHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
+; CHECK: @groushared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
 
 ; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @arrayofVecData.scalarized
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData
+; CHECK-NOT: @staticArrayOfVecData.scalarized
 ; CHECK-NOT: @groushared2dArrayofVectors
-
+; CHECK-NOT: @groushared2dArrayofVectors.scalarized
 
 ; CHECK-LABEL: load_array_vec_test
 define <4 x i32> @load_array_vec_test() #0 {
-  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
+  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align 4
   ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
   %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
   %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
@@ -39,7 +45,8 @@ define <4 x i32> @load_vec_test() #0 {
 
 ; CHECK-LABEL: load_static_array_of_vec_test
 define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
-  ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+  ; DATACHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+  ; CHECK: getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 %index
   ; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
   ; CHECK-NOT: load i32, ptr {{.*}}, align 4
   %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
@@ -49,7 +56,7 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
 
 ; CHECK-LABEL: multid_load_test
 define <4 x i32> @multid_load_test() #0 {
-  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4
+  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.1dim.*|%.*)}}, align 4
   ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
   %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
   %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index c45481e8cae14f..90630ca4e92b84 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,scalarizer<load-store>,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,scalarizer<load-store>,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 
 ; Make sure we can store groupshared, static vectors and arrays of vectors
@@ -6,15 +7,17 @@
 @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
 
-; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; DATACHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
 ; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @arrayofVecData.scalarized
 ; CHECK-NOT: @vecData
 
 ; CHECK-LABEL: store_array_vec_test
 define void @store_array_vec_test () local_unnamed_addr #0 {
-    ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
-    ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+    ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align {{4|8|16}}
+    ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align {{4|8|16}}
     store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"arrayofVecData", align 16 
     store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3)   getelementptr inbounds (i8, ptr addrspace(3) @"arrayofVecData", i32 16), align 16 
     ret void

>From c9f0d39b06f9c18560437b7d6bcd64540a5871f9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 30 Oct 2024 19:27:09 -0400
Subject: [PATCH 2/3] Split constant-expression GEPs used by loads/stores into instructions

---
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 86 +++++++++----------
 llvm/test/CodeGen/DirectX/flatten-array.ll    |  8 +-
 2 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index e4660909c438ee..20c7401e934e6c 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <cassert>
@@ -69,8 +70,8 @@ class DXILFlattenArraysVisitor
   bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
   bool visitPHINode(PHINode &PHI) { return false; }
-  bool visitLoadInst(LoadInst &LI) { return false; }
-  bool visitStoreInst(StoreInst &SI) { return false; }
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
   bool visitCallInst(CallInst &ICI) { return false; }
   bool visitFreezeInst(FreezeInst &FI) { return false; }
   static bool isMultiDimensionalArray(Type *T);
@@ -94,7 +95,6 @@ class DXILFlattenArraysVisitor
                          SmallVector<Value *> Indices = SmallVector<Value *>(),
                          SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
                          bool AllIndicesAreConstInt = true);
-  ConstantInt *computeFlatIndex(GetElementPtrInst &GEP);
   bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
   bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
                                             GetElementPtrInst &GEP);
@@ -164,6 +164,38 @@ Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
   return FlatIndex;
 }
 
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+  unsigned NumOperands = LI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
+    Value *CurrOpperand = LI.getOperand(I);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      convertUsersOfConstantsToInstructions(CE,
+                                            /*RestrictToFunc=*/nullptr,
+                                            /*RemoveDeadConstants=*/false,
+                                            /*IncludeSelf=*/true);
+      return false;
+    }
+  }
+  return false;
+}
+
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+  unsigned NumOperands = SI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
+    Value *CurrOpperand = SI.getOperand(I);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      convertUsersOfConstantsToInstructions(CE,
+                                            /*RestrictToFunc=*/nullptr,
+                                            /*RemoveDeadConstants=*/false,
+                                            /*IncludeSelf=*/true);
+      return false;
+    }
+  }
+  return false;
+}
+
 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
   if (!isMultiDimensionalArray(AI.getAllocatedType()))
     return false;
@@ -182,41 +214,6 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
   return true;
 }
 
-ConstantInt *
-DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
-  unsigned IndexAmount = GEP.getNumIndices();
-  assert(IndexAmount >= 1 && "Need At least one Index");
-  if (IndexAmount == 1)
-    return dyn_cast<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1));
-
-  // Get the type of the base pointer.
-  Type *BaseType = GEP.getSourceElementType();
-
-  // Determine the dimensions of the multi-dimensional array.
-  SmallVector<int64_t> Dimensions;
-  while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
-    Dimensions.push_back(ArrType->getNumElements());
-    BaseType = ArrType->getElementType();
-  }
-  unsigned FlatIndex = 0;
-  unsigned Multiplier = 1;
-  unsigned BitWidth = 32;
-  for (const Use &Index : GEP.indices()) {
-    ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
-    BitWidth = CurrentIndex->getBitWidth();
-    if (!CurrentIndex)
-      return nullptr;
-    int64_t IndexValue = CurrentIndex->getSExtValue();
-    FlatIndex += IndexValue * Multiplier;
-
-    if (!Dimensions.empty()) {
-      Multiplier *= Dimensions.back(); // Use the last dimension size
-      Dimensions.pop_back();           // Remove the last dimension
-    }
-  }
-  return ConstantInt::get(GEP.getContext(), APInt(BitWidth, FlatIndex));
-}
-
 void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
     GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
     Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
@@ -240,12 +237,13 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
   for (auto *User : CurrGEP.users()) {
     if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
       recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
-                             ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt);
+                             ++GEPChainUseCount, Indices, Dims,
+                             AllIndicesAreConstInt);
       GepUses = true;
     }
   }
   // This case is just incase the gep chain doesn't end with a 1d array.
-  if(IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
+  if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
     GEPChainMap.insert(
         {&CurrGEP,
          {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
@@ -295,10 +293,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
 
   unsigned GEPChainUseCount = 0;
   recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-  
+
   // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
   // Here recursion is used to get the length of the GEP chain.
-  // Handle zero uses here because there won't be an update via 
+  // Handle zero uses here because there won't be an update via
   // a child in the chain later.
   if (GEPChainUseCount == 0) {
     SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
@@ -308,7 +306,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
                     std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
     return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
   }
-  
+
   PotentiallyDeadInstrs.emplace_back(&GEP);
   return false;
 }
@@ -426,7 +424,7 @@ static bool flattenArrays(Module &M) {
   for (auto &[Old, New] : GlobalMap) {
     Old->replaceAllUsesWith(New);
     Old->eraseFromParent();
-    MadeChange |= true;
+    MadeChange = true;
   }
   return MadeChange;
 }
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index f3815b7f270717..2dd7cd988837b5 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -125,7 +125,8 @@ define void @gep_4d_test ()  {
 @b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16
 
 define void @global_gep_load() {
-  ; CHECK: load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 6), align 4
+  ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 6
+  ; CHECK: load i32, ptr [[GEP_PTR]], align 4
   ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
   ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
   ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
@@ -180,7 +181,8 @@ define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
 }
 
 define void @global_gep_store() {
-  ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 13), align 4
+  ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @b.1dim, i32 13
+  ; CHECK:  store i32 1, ptr [[GEP_PTR]], align 4
   ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
   ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
   ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
@@ -189,4 +191,4 @@ define void @global_gep_store() {
   %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 1
   store i32 1, i32* %3, align 4
   ret void
-}
\ No newline at end of file
+}

>From 9d36d73404a1a9f9bca9bcb276147f4542119337 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 4 Nov 2024 14:15:00 -0500
Subject: [PATCH 3/3] Address PR feedback

---
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 52 ++++++++-----------
 llvm/lib/Target/DirectX/DXILFlattenArrays.h   |  2 -
 2 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 20c7401e934e6c..65b5c2a2764c6e 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -5,10 +5,9 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===---------------------------------------------------------------------===//
-
 ///
 /// \file This file contains a pass to flatten arrays for the DirectX Backend.
-//
+///
 //===----------------------------------------------------------------------===//
 
 #include "DXILFlattenArrays.h"
@@ -26,10 +25,12 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <utility>
 
 #define DEBUG_TYPE "dxil-flatten-arrays"
 
 using namespace llvm;
+namespace {
 
 class DXILFlattenArraysLegacy : public ModulePass {
 
@@ -75,17 +76,16 @@ class DXILFlattenArraysVisitor
   bool visitCallInst(CallInst &ICI) { return false; }
   bool visitFreezeInst(FreezeInst &FI) { return false; }
   static bool isMultiDimensionalArray(Type *T);
-  static unsigned getTotalElements(Type *ArrayTy);
-  static Type *getBaseElementType(Type *ArrayTy);
+  static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
 
 private:
-  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+  SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
   DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
   bool finish();
-  ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+  ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
                                    ArrayRef<uint64_t> Dims,
                                    IRBuilder<> &Builder);
-  Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+  Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
                                    ArrayRef<uint64_t> Dims,
                                    IRBuilder<> &Builder);
   void
@@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor
   bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
                                             GetElementPtrInst &GEP);
 };
+} // namespace
 
 bool DXILFlattenArraysVisitor::finish() {
   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
@@ -111,25 +112,17 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
   return false;
 }
 
-unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+std::pair<unsigned, Type *> DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
   unsigned TotalElements = 1;
   Type *CurrArrayTy = ArrayTy;
   while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
     TotalElements *= InnerArrayTy->getNumElements();
     CurrArrayTy = InnerArrayTy->getElementType();
   }
-  return TotalElements;
-}
-
-Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
-  Type *CurrArrayTy = ArrayTy;
-  while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
-    CurrArrayTy = InnerArrayTy->getElementType();
-  }
-  return CurrArrayTy;
+  return std::make_pair(TotalElements, CurrArrayTy);
 }
 
-ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
   assert(Indices.size() == Dims.size() &&
          "Indicies and dimmensions should be the same");
@@ -146,7 +139,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
   return Builder.getInt32(FlatIndex);
 }
 
-Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
   if (Indices.size() == 1)
     return Indices[0];
@@ -202,10 +195,10 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
 
   ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
   IRBuilder<> Builder(&AI);
-  unsigned TotalElements = getTotalElements(ArrType);
+  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
 
   ArrayType *FattenedArrayType =
-      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+      ArrayType::get(BaseType, TotalElements);
   AllocaInst *FlatAlloca =
       Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
   FlatAlloca->setAlignment(AI.getAlign());
@@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
   IRBuilder<> Builder(&GEP);
   Value *FlatIndex;
   if (GEPInfo.AllIndicesAreConstInt)
-    FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+    FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
   else
     FlatIndex =
-        instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+        genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
 
   ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
   Value *FlatGEP =
@@ -285,9 +278,9 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
 
   ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
   IRBuilder<> Builder(&GEP);
-  unsigned TotalElements = getTotalElements(ArrType);
+  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
   ArrayType *FlattenedArrayType =
-      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+      ArrayType::get(BaseType, TotalElements);
 
   Value *PtrOperand = GEP.getPointerOperand();
 
@@ -313,7 +306,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
 
 bool DXILFlattenArraysVisitor::visit(Function &F) {
   bool MadeChange = false;
-  ////for (BasicBlock &BB : make_early_inc_range(F)) {
   ReversePostOrderTraversal<Function *> RPOT(&F);
   for (BasicBlock *BB : make_early_inc_range(RPOT)) {
     for (Instruction &I : make_early_inc_range(*BB)) {
@@ -345,8 +337,7 @@ static void collectElements(Constant *Init,
       collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
     }
   } else {
-    assert(
-        false &&
+    llvm_unreachable (
         "Expected a ConstantArray or ConstantDataArray for array initializer!");
   }
 }
@@ -382,10 +373,9 @@ flattenGlobalArrays(Module &M,
       continue;
 
     ArrayType *ArrType = cast<ArrayType>(OrigType);
-    unsigned TotalElements =
-        DXILFlattenArraysVisitor::getTotalElements(ArrType);
+    auto [TotalElements, BaseType] = DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
     ArrayType *FattenedArrayType = ArrayType::get(
-        DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+        BaseType, TotalElements);
 
     // Create a new global variable with the updated type
     // Note: Initializer is set via transformInitializer
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
index 409f8d198782c9..aae68496af620a 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.h
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -9,9 +9,7 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
 #define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
 
-#include "DXILResource.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/Pass.h"
 
 namespace llvm {
 



More information about the llvm-commits mailing list