[llvm] [DirectX] Flatten arrays (PR #114332)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 17:03:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: Farzon Lotfi (farzonl)
<details>
<summary>Changes</summary>
- The relevant piece is `DXILFlattenArrays.cpp`.
- The load and store instruction visitors only look for GetElementPtr constant expressions and split them into instructions.
- Allocas needed to be replaced with flattened allocas.
- Global arrays are handled similarly to allocas; the only interesting piece here is flattening their initializers.
- Most of the work went into building correct GEP chains. The approach here was a recursive strategy via `recursivelyCollectGEPs`.
- All intermediary GEPs get marked for deletion and only the leaf GEPs get updated with the new index.
completes [89646](https://github.com/llvm/llvm-project/issues/89646)
---
Patch is 42.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114332.diff
11 Files Affected:
- (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
- (added) llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (+458)
- (added) llvm/lib/Target/DirectX/DXILFlattenArrays.h (+25)
- (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
- (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+1)
- (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+3)
- (added) llvm/test/CodeGen/DirectX/flatten-array.ll (+194)
- (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
- (modified) llvm/test/CodeGen/DirectX/scalar-data.ll (+8-3)
- (modified) llvm/test/CodeGen/DirectX/scalar-load.ll (+15-8)
- (modified) llvm/test/CodeGen/DirectX/scalar-store.ll (+7-4)
``````````diff
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 5d1dc50fdb0dde..a726071e0dcecd 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(DirectXCodeGen
DXContainerGlobals.cpp
DXILDataScalarization.cpp
DXILFinalizeLinkage.cpp
+ DXILFlattenArrays.cpp
DXILIntrinsicExpansion.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
new file mode 100644
index 00000000000000..20c7401e934e6c
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -0,0 +1,458 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+/// Legacy pass manager wrapper around the DXIL array-flattening transform.
+class DXILFlattenArraysLegacy : public ModulePass {
+public:
+  static char ID; // Pass identification.
+
+  DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+};
+
+/// Everything needed to rewrite one GEP as a single GEP into the flattened
+/// (one-dimensional) array type.
+struct GEPData {
+ // Flattened 1-D array type the replacement GEP indexes into.
+ ArrayType *ParentArrayType;
+ // Pointer operand of the root GEP of the chain; the replacement GEP is based
+ // on it. NOTE(review): "Parend" looks like a typo for "Parent"; renaming it
+ // requires updating every use.
+ Value *ParendOperand;
+ // Last index operand of each GEP in the chain, outermost level first.
+ SmallVector<Value *> Indices;
+ // Element count of the array type indexed at each level; parallel to Indices.
+ SmallVector<uint64_t> Dims;
+ // True when every entry of Indices is a ConstantInt, so the flat index can be
+ // folded to a constant instead of emitted as mul/add instructions.
+ bool AllIndicesAreConstInt;
+};
+
+/// Instruction visitor that rewrites allocas and GEPs of multi-dimensional
+/// arrays into equivalent operations on one-dimensional arrays.
+class DXILFlattenArraysVisitor
+ : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+ DXILFlattenArraysVisitor() {}
+ // Visits every instruction in the reachable blocks of F, then runs finish().
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was flattened
+ // (rewritten), false if nothing changed. visit(Function &) erases a
+ // void-typed instruction whose visitor returned true; the alloca and GEP
+ // visitors erase the instruction they replace themselves.
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitAllocaInst(AllocaInst &AI);
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+ // Array type helpers; static so flattenGlobalArrays() can reuse them.
+ static bool isMultiDimensionalArray(Type *T);
+ static unsigned getTotalElements(Type *ArrayTy);
+ static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+ // Root GEPs of collected chains, deleted by finish() once trivially dead.
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+ // GEPs that end a chain, mapped to the data needed to rewrite them.
+ DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+ // Per-function cleanup run at the end of visit(Function &).
+ bool finish();
+ // Row-major flat index of all-constant indices, as an i32 constant.
+ ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ // Emits mul/add instructions computing the row-major flat index.
+ Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ // Records the chain of nested GEPs rooted at CurrGEP into GEPChainMap.
+ void
+ recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+ ArrayType *FlattenedArrayType, Value *PtrOperand,
+ unsigned &GEPChainUseCount,
+ SmallVector<Value *> Indices = SmallVector<Value *>(),
+ SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+ bool AllIndicesAreConstInt = true);
+ // Rewrites a GEP that was recorded in GEPChainMap.
+ bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+ // Replaces GEP with a single GEP into the flattened array from GEPInfo.
+ bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+ GetElementPtrInst &GEP);
+};
+
+/// Per-function cleanup run after every instruction of a function has been
+/// visited: deletes queued chain roots that became trivially dead and drops
+/// the per-function GEP chain bookkeeping.
+bool DXILFlattenArraysVisitor::finish() {
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+  // GEPChainMap is keyed on GEPs of the function just processed. Entries for
+  // leaves that were already rewritten point at erased instructions, and a
+  // freed address could later be handed to an unrelated GEP in another
+  // function, which would then be mistaken for a recorded chain leaf.
+  GEPChainMap.clear();
+  return true;
+}
+
+/// Returns true if \p T is an array whose element type is itself an array.
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+  auto *OuterTy = dyn_cast<ArrayType>(T);
+  return OuterTy && OuterTy->getElementType()->isArrayTy();
+}
+
+/// Returns the product of the element counts of every nested array level of
+/// \p ArrayTy, or 1 if \p ArrayTy is not an array.
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+  unsigned Count = 1;
+  for (auto *Level = dyn_cast<ArrayType>(ArrayTy); Level;
+       Level = dyn_cast<ArrayType>(Level->getElementType()))
+    Count *= Level->getNumElements();
+  return Count;
+}
+
+/// Strips every array level from \p ArrayTy and returns the innermost element
+/// type, or \p ArrayTy itself if it is not an array.
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+  Type *ElemTy = ArrayTy;
+  while (ElemTy->isArrayTy())
+    ElemTy = ElemTy->getArrayElementType();
+  return ElemTy;
+}
+
+/// Folds a list of constant array indices into one row-major index into the
+/// flattened array and returns it as an i32 constant.
+/// \p Dims[I] is the element count of the level that \p Indices[I] selects
+/// into; every entry of \p Indices must be a ConstantInt.
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  assert(Indices.size() == Dims.size() &&
+         "Indicies and dimmensions should be the same");
+  unsigned FlatIndex = 0;
+  unsigned Multiplier = 1;
+
+  // Walk from the innermost level outwards: the innermost index has stride 1
+  // and each outer stride is the product of all inner extents.
+  for (size_t I = Indices.size(); I-- > 0;) {
+    // cast<> states the all-constant precondition and checks it in assertion
+    // builds, instead of a dyn_cast whose null result would be dereferenced.
+    FlatIndex += cast<ConstantInt>(Indices[I])->getZExtValue() * Multiplier;
+    Multiplier *= Dims[I];
+  }
+  return Builder.getInt32(FlatIndex);
+}
+
+/// Emits IR computing the row-major flat index for \p Indices, where
+/// \p Dims[I] is the element count of the level that \p Indices[I] selects
+/// into. A single index is returned unchanged. Otherwise each index is first
+/// sign-extended or truncated to i32 so that it matches the i32 stride
+/// constants; multiplying e.g. an i64 index by an i32 constant would build an
+/// ill-typed binary operator.
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  if (Indices.size() == 1)
+    return Indices[0];
+
+  Type *I32Ty = Builder.getInt32Ty();
+  Value *FlatIndex = Builder.getInt32(0);
+  unsigned Multiplier = 1;
+
+  for (size_t I = Indices.size(); I-- > 0;) {
+    // GEP indices are signed, so convert with sign semantics. This returns
+    // the index unchanged when it is already i32.
+    Value *Index = Builder.CreateSExtOrTrunc(Indices[I], I32Ty);
+    Value *ScaledIndex = Builder.CreateMul(Index, Builder.getInt32(Multiplier));
+    FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
+    Multiplier *= Dims[I];
+  }
+  return FlatIndex;
+}
+
+/// Expands the first GetElementPtr constant-expression operand of \p LI into
+/// instructions (convertUsersOfConstantsToInstructions with IncludeSelf).
+/// Always returns false, so visit() never erases the load.
+/// NOTE(review): if the expansion inserts the new GEP before LI, the
+/// early-increment walk in visit() has already passed it and will not visit
+/// it — confirm whether that GEP still needs flattening.
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+  for (Value *Op : LI.operands()) {
+    auto *CE = dyn_cast<ConstantExpr>(Op);
+    if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
+      continue;
+    convertUsersOfConstantsToInstructions(CE,
+                                          /*RestrictToFunc=*/nullptr,
+                                          /*RemoveDeadConstants=*/false,
+                                          /*IncludeSelf=*/true);
+    // Only the first matching operand is expanded.
+    break;
+  }
+  return false;
+}
+
+/// Expands the first GetElementPtr constant-expression operand of \p SI
+/// (either the stored value or the pointer) into instructions. Always returns
+/// false, so visit() never erases the store.
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+  for (Value *Operand : SI.operands()) {
+    auto *GEPExpr = dyn_cast<ConstantExpr>(Operand);
+    if (GEPExpr && GEPExpr->getOpcode() == Instruction::GetElementPtr) {
+      convertUsersOfConstantsToInstructions(GEPExpr,
+                                            /*RestrictToFunc=*/nullptr,
+                                            /*RemoveDeadConstants=*/false,
+                                            /*IncludeSelf=*/true);
+      return false;
+    }
+  }
+  return false;
+}
+
+/// Replaces an alloca of a multi-dimensional array with an alloca of the
+/// equivalent one-dimensional array (same base element type, same total
+/// element count), then erases \p AI and returns true.
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+  if (!isMultiDimensionalArray(AI.getAllocatedType()))
+    return false;
+
+  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+  IRBuilder<> Builder(&AI);
+  unsigned TotalElements = getTotalElements(ArrType);
+
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+  // Keep the original address space and array-allocation count. The overload
+  // without an address space uses the DataLayout's default alloca address
+  // space, and passing no array size always allocates a single element.
+  AllocaInst *FlatAlloca =
+      Builder.CreateAlloca(FlattenedArrayType, AI.getAddressSpace(),
+                           AI.getArraySize(), AI.getName() + ".flat");
+  FlatAlloca->setAlignment(AI.getAlign());
+  AI.replaceAllUsesWith(FlatAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
+/// Walks the chain of nested GEPs rooted at \p CurrGEP and records in
+/// GEPChainMap, for every GEP that ends the chain, the indices and array
+/// extents needed to rewrite it as one GEP into \p FlattenedArrayType based
+/// on \p PtrOperand.
+///
+/// \p Indices and \p Dims are taken by value on purpose: each sibling branch
+/// of the recursion extends its own copy of the prefix collected so far.
+/// \p GEPChainUseCount is incremented once per nested GEP visited, so the
+/// caller can tell whether the root had any GEP users at all.
+///
+/// Only the last index operand of each GEP is used, so this assumes GEPs of
+/// the form `getelementptr [N x T], ptr %p, 0, %idx`.
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
+    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
+  Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+  AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
+  Indices.push_back(LastIndex);
+  assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
+  Dims.push_back(
+      cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
+
+  // A GEP into a one-dimensional array is a leaf of the chain: the collected
+  // indices fully describe its position in the flattened array.
+  if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
+    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
+    GEPChainMap.insert({&CurrGEP,
+                        {FlattenedArrayType, PtrOperand, std::move(Indices),
+                         std::move(Dims), AllIndicesAreConstInt}});
+    return;
+  }
+
+  bool GepUses = false;
+  for (auto *User : CurrGEP.users()) {
+    if (auto *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
+      recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+                             ++GEPChainUseCount, Indices, Dims,
+                             AllIndicesAreConstInt);
+      GepUses = true;
+    }
+  }
+
+  // The chain stops at a multi-dimensional level: no nested GEP consumes this
+  // one, so its result points at a whole sub-array. Its last index must
+  // therefore be scaled by the number of base elements in that sub-array, not
+  // by 1. Appending a zero index whose extent is that element count yields
+  // exactly this stride in the row-major flattening.
+  if (GEPChainUseCount > 0 && !GepUses) {
+    Indices.push_back(
+        ConstantInt::get(Type::getInt32Ty(CurrGEP.getContext()), 0));
+    Dims.push_back(getTotalElements(CurrGEP.getResultElementType()));
+    GEPChainMap.insert({&CurrGEP,
+                        {FlattenedArrayType, PtrOperand, std::move(Indices),
+                         std::move(Dims), AllIndicesAreConstInt}});
+  }
+}
+
+/// Rewrites a GEP that was recorded as the end of a chain. The map entry is
+/// removed before the GEP is erased, so GEPChainMap never keeps a key that
+/// points at a freed instruction.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+    GetElementPtrInst &GEP) {
+  auto It = GEPChainMap.find(&GEP);
+  assert(It != GEPChainMap.end() && "GEP is not part of a recorded chain");
+  // Move the data out rather than copying its two SmallVectors.
+  GEPData GEPInfo = std::move(It->second);
+  GEPChainMap.erase(It);
+  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+/// Replaces \p GEP with a single GEP into GEPInfo.ParentArrayType (the
+/// flattened array), based on GEPInfo.ParendOperand and indexed by the
+/// row-major flattening of GEPInfo.Indices. Erases \p GEP and returns true.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+    GEPData &GEPInfo, GetElementPtrInst &GEP) {
+  IRBuilder<> Builder(&GEP);
+  Value *FlatIndex =
+      GEPInfo.AllIndicesAreConstInt
+          ? constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder)
+          : instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+  // The first GEP index steps through the base pointer itself, in units of
+  // the whole [N x T] array; the second selects an element inside it. With
+  // FlatIndex as the only index, the GEP would advance FlatIndex whole arrays
+  // instead of FlatIndex elements, so a leading zero is required.
+  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
+  Value *FlatGEP = Builder.CreateGEP(
+      FlattenedArrayType, GEPInfo.ParendOperand,
+      {Builder.getInt32(0), FlatIndex}, GEP.getName() + ".flat",
+      GEP.isInBounds());
+
+  GEP.replaceAllUsesWith(FlatGEP);
+  GEP.eraseFromParent();
+  return true;
+}
+
+/// Entry point for every GEP.
+/// - A GEP recorded as the end of a chain is rewritten directly.
+/// - A GEP into a multi-dimensional array starts a new chain: its nested GEP
+///   users are recorded for rewriting later and the root is queued for
+///   deletion.
+/// - A root with no nested GEP users is rewritten here.
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  if (GEPChainMap.contains(&GEP))
+    return visitGetElementPtrInstInGEPChain(GEP);
+  if (!isMultiDimensionalArray(GEP.getSourceElementType()))
+    return false;
+
+  ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
+  unsigned TotalElements = getTotalElements(ArrType);
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+
+  Value *PtrOperand = GEP.getPointerOperand();
+
+  unsigned GEPChainUseCount = 0;
+  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
+  // GEPChainUseCount only counts nested GEP users found by the recursion. With
+  // none, no later visit of a chain leaf will rewrite this GEP, so do it here.
+  if (GEPChainUseCount == 0) {
+    // The result points at a whole sub-array of the original type, so the
+    // last index must be scaled by that sub-array's base element count.
+    // Appending a zero index whose extent is that count yields this stride.
+    Value *LastIndex = GEP.getOperand(GEP.getNumOperands() - 1);
+    SmallVector<Value *> Indices(
+        {LastIndex, ConstantInt::get(Type::getInt32Ty(GEP.getContext()), 0)});
+    SmallVector<uint64_t> Dims({ArrType->getNumElements(),
+                                getTotalElements(GEP.getResultElementType())});
+    bool AllIndicesAreConstInt = isa<ConstantInt>(LastIndex);
+    GEPData GEPInfo{FlattenedArrayType, PtrOperand, std::move(Indices),
+                    std::move(Dims), AllIndicesAreConstInt};
+    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+  }
+
+  PotentiallyDeadInstrs.emplace_back(&GEP);
+  return false;
+}
+
+/// Visits every instruction in the reachable blocks of \p F in reverse
+/// post-order, so a dominating definition such as the root of a GEP chain is
+/// visited before the instructions that use it. Then runs finish().
+/// Returns true if any instruction was rewritten.
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : RPOT) {
+    for (Instruction &I : make_early_inc_range(*BB)) {
+      // Query the type before dispatching: the alloca and GEP visitors erase
+      // I themselves, so I must not be touched after they return true.
+      bool IsVoid = I.getType()->isVoidTy();
+      if (!InstVisitor::visit(I))
+        continue;
+      MadeChange = true;
+      // A void instruction reported as handled is erased here.
+      if (IsVoid)
+        I.eraseFromParent();
+    }
+  }
+  finish();
+  return MadeChange;
+}
+
+/// Appends the non-array leaves of \p Init to \p Elements in row-major order.
+/// Elements are fetched through Constant::getAggregateElement. That call also
+/// handles nested zeroinitializer and undef/poison sub-arrays, which are
+/// neither ConstantArray nor ConstantDataArray.
+static void collectElements(Constant *Init,
+                            SmallVectorImpl<Constant *> &Elements) {
+  // Base case: a non-array constant is one element of the flattened array.
+  auto *ArrTy = dyn_cast<ArrayType>(Init->getType());
+  if (!ArrTy) {
+    Elements.push_back(Init);
+    return;
+  }
+
+  // Recursive case: visit each element of this array level.
+  for (unsigned I = 0, E = ArrTy->getNumElements(); I != E; ++I) {
+    Constant *Elt = Init->getAggregateElement(I);
+    assert(Elt && "Unsupported constant kind for array initializer!");
+    if (!Elt)
+      return; // Leave Elements short; the caller's size check reports it.
+    collectElements(Elt, Elements);
+  }
+}
+
+/// Builds the initializer for a flattened global from \p Init, the
+/// initializer of the original global of type \p OrigType. \p Ctx is unused.
+static Constant *transformInitializer(Constant *Init, Type *OrigType,
+                                      ArrayType *FlattenedType,
+                                      LLVMContext &Ctx) {
+  // zeroinitializer has no per-element data: re-create it for the new type.
+  if (isa<ConstantAggregateZero>(Init))
+    return ConstantAggregateZero::get(FlattenedType);
+  // Any UndefValue (poison included, as a subclass) becomes undef.
+  if (isa<UndefValue>(Init))
+    return UndefValue::get(FlattenedType);
+  // Non-array values need no restructuring.
+  if (!OrigType->isArrayTy())
+    return Init;
+
+  // Gather the leaves in row-major order and rebuild them as one 1-D array.
+  SmallVector<Constant *> Leaves;
+  collectElements(Init, Leaves);
+  assert(Leaves.size() == FlattenedType->getNumElements() &&
+         "The number of collected elements should match the FlattenedType");
+  return ConstantArray::get(FlattenedType, Leaves);
+}
+
+/// Creates a one-dimensional replacement, named "<name>.1dim", for every
+/// global whose value type is a multi-dimensional array. Records the
+/// old -> new mapping in \p GlobalMap. The old globals are left in place; the
+/// caller replaces their uses and erases them.
+static void
+flattenGlobalArrays(Module &M,
+                    DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+  LLVMContext &Ctx = M.getContext();
+  for (GlobalVariable &G : M.globals()) {
+    Type *OrigType = G.getValueType();
+    if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
+      continue;
+
+    ArrayType *ArrType = cast<ArrayType>(OrigType);
+    unsigned TotalElements =
+        DXILFlattenArraysVisitor::getTotalElements(ArrType);
+    ArrayType *FlattenedArrayType = ArrayType::get(
+        DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+
+    // The replacement is inserted right before G, so this loop does not
+    // revisit it. Its initializer is set below.
+    GlobalVariable *NewGlobal = new GlobalVariable(
+        M, FlattenedArrayType, G.isConstant(), G.getLinkage(),
+        /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
+        G.getThreadLocalMode(), G.getAddressSpace(),
+        G.isExternallyInitialized());
+
+    // Copy all remaining global attributes from G, not only alignment and
+    // unnamed_addr; copyAttributesFrom also carries over e.g. visibility,
+    // section and DLL storage class.
+    NewGlobal->copyAttributesFrom(&G);
+
+    if (G.hasInitializer())
+      NewGlobal->setInitializer(transformInitializer(
+          G.getInitializer(), OrigType, FlattenedArrayType, Ctx));
+    GlobalMap[&G] = NewGlobal;
+  }
+}
+
+/// Flattens multi-dimensional arrays in the globals and in every function
+/// body of \p M, then swaps the old globals for their flattened replacements.
+/// Returns true if the module changed.
+static bool flattenArrays(Module &M) {
+  bool MadeChange = false;
+  DXILFlattenArraysVisitor Impl;
+  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  flattenGlobalArrays(M, GlobalMap);
+  for (auto &F : make_early_inc_range(M.functions())) {
+    // Declarations, intrinsics included, have no body. Building a reverse
+    // post-order traversal of one would dereference a missing entry block.
+    if (F.isDeclaration())
+      continue;
+    MadeChange |= Impl.visit(F);
+  }
+  for (auto &[Old, New] : GlobalMap) {
+    Old->replaceAllUsesWith(New);
+    Old->eraseFromParent();
+    MadeChange = true;
+  }
+  return MadeChange;
+}
+
+/// New pass manager entry point.
+PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
+  // Nothing was rewritten: every analysis is still valid.
+  if (!flattenArrays(M))
+    return PreservedAnalyses::all();
+  // Otherwise only the DXIL resource analysis is declared preserved.
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  return PA;
+}
+
+/// Legacy pass manager entry point; delegates to flattenArrays.
+bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
+  const bool Changed = flattenArrays(M);
+  return Changed;
+}
+
+/// Declares that the legacy DXIL resource wrapper pass remains valid after
+/// this pass runs.
+void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addPreserved<DXILResourceWrapperPass>();
+}
+
+// Pass identification storage for the legacy pass.
+char DXILFlattenArraysLegacy::ID = 0;
+
+// Register the legacy pass under DEBUG_TYPE ("dxil-flatten-arrays"). The two
+// trailing `false` arguments are the CFG-only and is-analysis flags.
+INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
+ "DXIL Array Flattener", false, false)
+INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
+ false, false)
+
+/// Returns a newly allocated instance of the legacy array-flattening pass;
+/// ownership passes to the caller.
+ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
+  auto *P = new DXILFlattenArraysLegacy();
+  return P;
+}
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
new file mode 100644
index 00000000000000..409f8d198782c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -0,0 +1,25 @@
+//===- DXILFlattenArrays.h - Perform flattening of DXIL Arrays -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms multidimensional arrays into one-dimensional arrays.
+class DXILFlattenArrays : public PassInfoMixin<DXILFlattenArrays> {
+public:
+ /// Flattens multi-dimensional arrays used by globals, allocas and GEPs in M.
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3221779be2f311..3454f16ecd5955 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -40,6 +40,12 @@ void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
/// Pass to scalarize llvm global data into a DXIL legal form
ModulePass *createDXILDataScalarizationLegacyPass();
+/// Initializer for the legacy DXIL array flattening pass.
+void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
+
+/// Pass to flatten multi-dimensional arrays into a one-dimensional,
+/// DXIL-legal form.
+ModulePass *createDXILFlattenArraysLegacyPass();
+
/// Initializer for DXILOpLowering
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index ae729a1082b867..a0f864ed39375f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -24,6 +24,7 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
#define MODULE_PASS(NAME, CREATE_PASS)
#endif
MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
+MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion(...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/114332
More information about the llvm-commits
mailing list