[llvm] [DirectX] Flatten arrays (PR #114332)

Justin Bogner via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 09:25:45 PST 2024


================
@@ -0,0 +1,458 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+///
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+namespace {
+
+/// Legacy pass manager (ModulePass) entry point for the pass implemented in
+/// this file. Kept in an anonymous namespace because the class is only used
+/// from this translation unit.
+class DXILFlattenArraysLegacy : public ModulePass {
+public:
+  static char ID; // Pass identification.
+
+  DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+};
+
+} // end anonymous namespace
+
+/// Information recorded for a GEP at the end of a GEP chain so that it can be
+/// replaced by a single GEP into the flattened array.
+struct GEPData {
+  /// Array type used as the source element type of the replacement GEP.
+  ArrayType *ParentArrayType = nullptr;
+  /// Pointer operand used by the replacement GEP.
+  Value *ParendOperand = nullptr;
+  /// Last index operand of every GEP in the chain, in the order the chain was
+  /// walked (root GEP first).
+  SmallVector<Value *> Indices;
+  /// Element count of each GEP's source array type; parallel to Indices.
+  SmallVector<uint64_t> Dims;
+  /// True when every entry of Indices is a ConstantInt.
+  bool AllIndicesAreConstInt = false;
+};
+
+/// Instruction visitor that rewrites allocas and GEPs whose type is a
+/// multi-dimensional array so that they operate on a one-dimensional array
+/// with the same total number of elements.
+class DXILFlattenArraysVisitor
+    : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+  DXILFlattenArraysVisitor() = default;
+
+  /// Entry point: processes every instruction of \p F.
+  bool visit(Function &F);
+
+  // InstVisitor methods. They return true if the instruction was rewritten,
+  // false if nothing changed.
+  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+  bool visitAllocaInst(AllocaInst &AI);
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+
+  // Instruction kinds this visitor never rewrites.
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitSelectInst(SelectInst &SI) { return false; }
+  bool visitICmpInst(ICmpInst &ICI) { return false; }
+  bool visitFCmpInst(FCmpInst &FCI) { return false; }
+  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+  bool visitCastInst(CastInst &CI) { return false; }
+  bool visitBitCastInst(BitCastInst &BCI) { return false; }
+  bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+  bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+  bool visitPHINode(PHINode &PHI) { return false; }
+  bool visitCallInst(CallInst &ICI) { return false; }
+  bool visitFreezeInst(FreezeInst &FI) { return false; }
+
+  /// Returns true if \p T is an array type whose element type is also an
+  /// array type.
+  static bool isMultiDimensionalArray(Type *T);
+  /// Returns the product of the element counts of all nested array levels.
+  static unsigned getTotalElements(Type *ArrayTy);
+  /// Returns the innermost non-array element type of \p ArrayTy.
+  static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+  /// GEPs that were replaced through a child in their chain and may have
+  /// become dead; cleaned up by finish().
+  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+  /// Chain-ending GEPs discovered so far, with the data needed to rewrite
+  /// them once they are visited.
+  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+
+  bool finish();
+  ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  void
+  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+                         ArrayType *FlattenedArrayType, Value *PtrOperand,
+                         unsigned &GEPChainUseCount,
+                         SmallVector<Value *> Indices = SmallVector<Value *>(),
+                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+                         bool AllIndicesAreConstInt = true);
+  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+                                            GetElementPtrInst &GEP);
+};
+
+/// Cleans up the instructions queued in PotentiallyDeadInstrs, deleting those
+/// that are trivially dead. Unconditionally returns true.
+bool DXILFlattenArraysVisitor::finish() {
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+  return true;
+}
+
+/// Returns true when \p T is an array whose elements are themselves arrays.
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+  auto *Outer = dyn_cast<ArrayType>(T);
+  if (!Outer)
+    return false;
+  return isa<ArrayType>(Outer->getElementType());
+}
+
+/// Multiplies together the element counts of every nested array level of
+/// \p ArrayTy. Returns 1 when \p ArrayTy is not an array type.
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+  unsigned Count = 1;
+  for (auto *Level = dyn_cast<ArrayType>(ArrayTy); Level;
+       Level = dyn_cast<ArrayType>(Level->getElementType()))
+    Count *= Level->getNumElements();
+  return Count;
+}
+
+/// Peels array levels off \p ArrayTy until a non-array type is reached and
+/// returns that type. A non-array input is returned unchanged.
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+  Type *Elem = ArrayTy;
+  for (auto *AsArray = dyn_cast<ArrayType>(Elem); AsArray;
+       AsArray = dyn_cast<ArrayType>(Elem))
+    Elem = AsArray->getElementType();
+  return Elem;
+}
+
+/// Folds a list of constant per-dimension indices into one flat index.
+///
+/// \param Indices One index per dimension; every entry must be a ConstantInt.
+/// \param Dims    Element count of each dimension; parallel to \p Indices.
+/// \param Builder Used only to materialize the resulting i32 constant.
+/// \returns The flat index as an i32 constant, where the last entry of
+///          \p Indices has stride 1 and each earlier entry is scaled by the
+///          product of the dimensions that follow it.
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  assert(Indices.size() == Dims.size() &&
+         "Indices and dimensions should be the same size");
+  unsigned FlatIndex = 0;
+  unsigned Multiplier = 1;
+
+  // Walk from the last dimension to the first, growing the stride as we go.
+  for (size_t I = Indices.size(); I-- > 0;) {
+    // cast<> asserts if an index is unexpectedly not a ConstantInt.
+    auto *CIndex = cast<ConstantInt>(Indices[I]);
+    FlatIndex += CIndex->getZExtValue() * Multiplier;
+    Multiplier *= Dims[I];
+  }
+  return Builder.getInt32(FlatIndex);
+}
+
+/// Emits instructions that combine per-dimension indices into one flat index.
+///
+/// \param Indices One index per dimension.
+/// \param Dims    Element count of each dimension; parallel to \p Indices.
+/// \param Builder Receives the generated mul/add instructions.
+/// \returns \p Indices[0] unchanged when there is a single index; otherwise
+///          the value of the final add, where the last index has stride 1 and
+///          each earlier index is scaled by the product of the dimensions
+///          that follow it.
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  assert(Indices.size() == Dims.size() &&
+         "Indices and dimensions should be the same size");
+  if (Indices.size() == 1)
+    return Indices[0];
+
+  Value *FlatIndex = Builder.getInt32(0);
+  unsigned Multiplier = 1;
+
+  // Walk from the last dimension to the first, growing the stride as we go.
+  for (size_t I = Indices.size(); I-- > 0;) {
+    Value *VMultiplier = Builder.getInt32(Multiplier);
+    Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
+    FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
+    Multiplier *= Dims[I];
+  }
+  return FlatIndex;
+}
+
+/// If an operand of the load is a GEP constant expression, converts the users
+/// of that constant expression into instructions. Only the first such operand
+/// is handled, and the function returns false in every case.
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+  for (Use &Op : LI.operands()) {
+    auto *CE = dyn_cast<ConstantExpr>(Op.get());
+    if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
+      continue;
+    convertUsersOfConstantsToInstructions(CE,
+                                          /*RestrictToFunc=*/nullptr,
+                                          /*RemoveDeadConstants=*/false,
+                                          /*IncludeSelf=*/true);
+    break;
+  }
+  return false;
+}
+
+/// If an operand of the store is a GEP constant expression, converts the
+/// users of that constant expression into instructions. Only the first such
+/// operand is handled, and the function returns false in every case.
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+  for (Use &Op : SI.operands()) {
+    auto *CE = dyn_cast<ConstantExpr>(Op.get());
+    if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
+      continue;
+    convertUsersOfConstantsToInstructions(CE,
+                                          /*RestrictToFunc=*/nullptr,
+                                          /*RemoveDeadConstants=*/false,
+                                          /*IncludeSelf=*/true);
+    break;
+  }
+  return false;
+}
+
+/// Replaces an alloca of a multi-dimensional array with an alloca of a
+/// one-dimensional array that has the same base element type, the same total
+/// element count and the same alignment. Returns true if the alloca was
+/// replaced, false if its allocated type is not a multi-dimensional array.
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+  Type *AllocatedTy = AI.getAllocatedType();
+  if (!isMultiDimensionalArray(AllocatedTy))
+    return false;
+
+  auto *NestedTy = cast<ArrayType>(AllocatedTy);
+  ArrayType *FlatTy = ArrayType::get(getBaseElementType(NestedTy),
+                                     getTotalElements(NestedTy));
+
+  // Create the replacement right before the original alloca.
+  IRBuilder<> Builder(&AI);
+  AllocaInst *FlatAlloca =
+      Builder.CreateAlloca(FlatTy, nullptr, AI.getName() + ".flat");
+  FlatAlloca->setAlignment(AI.getAlign());
+
+  AI.replaceAllUsesWith(FlatAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
+/// Walks the GEP users of \p CurrGEP depth-first, accumulating the last index
+/// and the source array's element count of every GEP on the path, and records
+/// the GEP that ends each path in GEPChainMap.
+///
+/// \p Indices, \p Dims and \p AllIndicesAreConstInt are taken by value so
+/// that each branch of the walk extends its own copy of the path state.
+/// \p GEPChainUseCount is incremented once for every nested GEP visited.
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
+    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
+  Type *SrcElemTy = CurrGEP.getSourceElementType();
+  assert(isa<ArrayType>(SrcElemTy));
+
+  // Extend the path with this GEP's contribution.
+  Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+  AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
+  Indices.push_back(LastIndex);
+  Dims.push_back(cast<ArrayType>(SrcElemTy)->getNumElements());
+
+  // Stores the accumulated path for CurrGEP. Consumes Indices and Dims.
+  auto RecordChainEnd = [&]() {
+    GEPChainMap.insert(
+        {&CurrGEP,
+         {FlattenedArrayType, PtrOperand, std::move(Indices), std::move(Dims),
+          AllIndicesAreConstInt}});
+  };
+
+  // A GEP over a one-dimensional array always ends the chain.
+  if (!isMultiDimensionalArray(SrcElemTy)) {
+    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
+    RecordChainEnd();
+    return;
+  }
+
+  bool HasGEPUser = false;
+  for (User *U : CurrGEP.users()) {
+    auto *NestedGEP = dyn_cast<GetElementPtrInst>(U);
+    if (!NestedGEP)
+      continue;
+    recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+                           ++GEPChainUseCount, Indices, Dims,
+                           AllIndicesAreConstInt);
+    HasGEPUser = true;
+  }
+
+  // This case is just in case the GEP chain doesn't end with a 1d array.
+  if (GEPChainUseCount > 0 && !HasGEPUser)
+    RecordChainEnd();
+}
+
+/// Rewrites a GEP that was previously recorded in GEPChainMap as the end of a
+/// GEP chain. Returns the result of the shared rewrite helper.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+    GetElementPtrInst &GEP) {
+  auto It = GEPChainMap.find(&GEP);
+  assert(It != GEPChainMap.end() && "GEP is not part of a recorded chain");
+  GEPData GEPInfo = std::move(It->second);
+  // The rewrite below erases GEP. Drop its map entry first so the map never
+  // holds a dangling key that a later instruction allocated at the same
+  // address could be mistaken for.
+  GEPChainMap.erase(It);
+  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+/// Replaces \p GEP with a single GEP into the flattened array described by
+/// \p GEPInfo, then erases \p GEP. The flat index is a constant when all
+/// recorded indices are ConstantInts and is computed with emitted
+/// instructions otherwise. Always returns true.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+    GEPData &GEPInfo, GetElementPtrInst &GEP) {
+  IRBuilder<> Builder(&GEP);
+
+  Value *FlatIndex =
+      GEPInfo.AllIndicesAreConstInt
+          ? constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder)
+          : instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+  Value *FlatGEP =
+      Builder.CreateGEP(GEPInfo.ParentArrayType, GEPInfo.ParendOperand,
+                        FlatIndex, GEP.getName() + ".flat", GEP.isInBounds());
+
+  GEP.replaceAllUsesWith(FlatGEP);
+  GEP.eraseFromParent();
+  return true;
+}
+
+/// Handles a GEP instruction.
+///
+/// - A GEP already recorded in GEPChainMap is rewritten from its recorded
+///   chain data.
+/// - A GEP whose source element type is not a multi-dimensional array is left
+///   alone.
+/// - Otherwise the GEP is treated as the root of a chain: its nested GEP
+///   users are collected, and the root is either rewritten immediately (no
+///   nested GEPs) or queued for dead-instruction cleanup.
+///
+/// Returns true only if \p GEP itself was replaced during this call.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  if (GEPChainMap.find(&GEP) != GEPChainMap.end())
+    return visitGetElementPtrInstInGEPChain(GEP);
+
+  Type *SrcElemTy = GEP.getSourceElementType();
+  if (!isMultiDimensionalArray(SrcElemTy))
+    return false;
+
+  auto *ArrType = cast<ArrayType>(SrcElemTy);
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), getTotalElements(ArrType));
+  Value *PtrOperand = GEP.getPointerOperand();
+
+  unsigned GEPChainUseCount = 0;
+  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
+  // Here recursion is used to get the length of the GEP chain.
+  // Handle zero uses here because there won't be an update via
+  // a child in the chain later.
+  if (GEPChainUseCount == 0) {
+    SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
+    SmallVector<uint64_t> Dims({ArrType->getNumElements()});
+    bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
+    GEPData GEPInfo{FlattenedArrayType, PtrOperand, std::move(Indices),
+                    std::move(Dims), AllIndicesAreConstInt};
+    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+  }
+
+  // The root is replaced through the GEPs recorded for its chain; remember it
+  // so it can be deleted once it has no remaining uses.
+  PotentiallyDeadInstrs.emplace_back(&GEP);
+  return false;
+}
+
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ////for (BasicBlock &BB : make_early_inc_range(F)) {
----------------
bogner wrote:

leftover commented code

https://github.com/llvm/llvm-project/pull/114332


More information about the llvm-commits mailing list