[llvm] [DirectX] Flatten arrays (PR #114332)
Cooper Partin via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 31 09:47:26 PDT 2024
================
@@ -0,0 +1,458 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+///
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+class DXILFlattenArraysLegacy : public ModulePass {
+ // ModulePass subclass for this pass (presumably the legacy pass-manager entry point; confirm).
+public:
+ bool runOnModule(Module &M) override; // Declared here; the definition is not in this class body.
+ DXILFlattenArraysLegacy() : ModulePass(ID) {} // Forwards the static ID below to the ModulePass base.
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override; // Declared here; the definition is not in this class body.
+ static char ID; // Pass identification.
+};
+
+struct GEPData { // Per-GEP record; this is the mapped type of DXILFlattenArraysVisitor::GEPChainMap.
+ ArrayType *ParentArrayType;
+ Value *ParendOperand; // NOTE(review): "Parend" reads like a typo for "Parent"; a rename must update every use.
+ SmallVector<Value *> Indices; // Same name/type as the Indices parameter of recursivelyCollectGEPs (presumably filled from it; confirm).
+ SmallVector<uint64_t> Dims; // Same name/type as the Dims parameter of recursivelyCollectGEPs (presumably parallel to Indices; confirm).
+ bool AllIndicesAreConstInt; // Presumably true only when every entry of Indices is a ConstantInt (confirm against the collector).
+};
+
+class DXILFlattenArraysVisitor
+ : public InstVisitor<DXILFlattenArraysVisitor, bool> { // bool is the return type of every visit method below.
+public:
+ DXILFlattenArraysVisitor() {}
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was rewritten by
+ // this pass, false if nothing changed. The inline ones always return false.
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitAllocaInst(AllocaInst &AI);
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+ static bool isMultiDimensionalArray(Type *T); // True when T is an array whose element type is itself an array.
+ static unsigned getTotalElements(Type *ArrayTy); // Multiplies getNumElements() across nested array levels (definition is only partly visible here).
+ static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; // Passed to RecursivelyDeleteTriviallyDeadInstructionsPermissive in finish().
+ DenseMap<GetElementPtrInst *, GEPData> GEPChainMap; // GEP instruction -> its GEPData record.
+ bool finish(); // Runs the dead-instruction cleanup and always returns true.
+ ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ void
+ recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+ ArrayType *FlattenedArrayType, Value *PtrOperand,
+ unsigned &GEPChainUseCount, // In/out: taken by non-const reference.
+ SmallVector<Value *> Indices = SmallVector<Value *>(), // By value: each call gets its own copy.
+ SmallVector<uint64_t> Dims = SmallVector<uint64_t>(), // By value: each call gets its own copy.
+ bool AllIndicesAreConstInt = true);
+ bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+ bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+ GetElementPtrInst &GEP);
+};
+
+bool DXILFlattenArraysVisitor::finish() { // Cleanup step; its call site is not visible in this chunk.
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); // Deletes queued instructions that are trivially dead; per the helper's name, others are presumably skipped (confirm in Local.h).
+ return true; // Unconditional: returned even when nothing was queued.
+}
+
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { // Returns true iff T is an array whose element type is also an array.
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+ return isa<ArrayType>(ArrType->getElementType()); // Only the first nesting level is inspected.
+ return false; // T is not an array type at all.
+}
+
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+ unsigned TotalElements = 1;
+ Type *CurrArrayTy = ArrayTy;
+ while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+ TotalElements *= InnerArrayTy->getNumElements();
----------------
coopp wrote:
I see.
https://github.com/llvm/llvm-project/pull/114332
More information about the llvm-commits
mailing list