[llvm] [DirectX] Flatten arrays (PR #114332)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 17:03:22 PDT 2024
https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/114332
- Relevant piece is `DXILFlattenArrays.cpp`
- Load and Store instruction visits are just for finding GetElementPtrConstantExpr operands and splitting them.
- Allocas needed to be replaced with flattened allocas.
- Global arrays were similar to allocas. The only interesting piece here is around initializers.
- Most of the work went into building correct GEP chains. The approach here was a recursive strategy via `recursivelyCollectGEPs`.
- All intermediary GEPs get marked for deletion and only the leaf GEPs get updated with the new index.
Completes [#89646](https://github.com/llvm/llvm-project/issues/89646)
>From 2c030f3e4ac04d2603dc3b60801ebc3d4b88160c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 10 Oct 2024 04:59:08 -0400
Subject: [PATCH 1/7] [DirectX] A pass to flatten arrays
---
llvm/lib/Target/DirectX/CMakeLists.txt | 1 +
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 282 ++++++++++++++++++
llvm/lib/Target/DirectX/DXILFlattenArrays.h | 25 ++
llvm/lib/Target/DirectX/DirectX.h | 6 +
.../Target/DirectX/DirectXTargetMachine.cpp | 3 +
llvm/test/CodeGen/DirectX/flatten-array.ll | 74 +++++
6 files changed, 391 insertions(+)
create mode 100644 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
create mode 100644 llvm/lib/Target/DirectX/DXILFlattenArrays.h
create mode 100644 llvm/test/CodeGen/DirectX/flatten-array.ll
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 5d1dc50fdb0dde..a726071e0dcecd 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(DirectXCodeGen
DXContainerGlobals.cpp
DXILDataScalarization.cpp
DXILFinalizeLinkage.cpp
+ DXILFlattenArrays.cpp
DXILIntrinsicExpansion.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
new file mode 100644
index 00000000000000..126a27b3902e7c
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -0,0 +1,282 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+class DXILFlattenArraysLegacy : public ModulePass {
+
+public:
+ bool runOnModule(Module &M) override;
+ DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ static char ID; // Pass identification.
+};
+
+struct GEPChainInfo {
+ DenseMap<GetElementPtrInst *, SmallVector<ConstantInt *>> GEPToIndicesMap;
+ DenseMap<GetElementPtrInst *, SmallVector<uint64_t>> GEPToDimmsMap;
+ DenseMap<GetElementPtrInst *, GetElementPtrInst *> GEPChildToNewParentMap;
+ DenseMap<Value *, GetElementPtrInst *> OperandToBaseGEPMap;
+};
+
+class DXILFlattenArraysVisitor
+ : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+ DXILFlattenArraysVisitor() {}
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was scalarized,
+ // false if nothing changed.
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitAllocaInst(AllocaInst &AI);
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI) { return false; }
+ bool visitStoreInst(StoreInst &SI) { return false; }
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+
+private:
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+ GEPChainInfo GEPChain;
+ bool finish();
+ bool isMultiDimensionalArray(Type *T);
+ ConstantInt *flattenIndices(ArrayRef<ConstantInt *> Indices,
+ ArrayRef<uint64_t> Dims, IRBuilder<> &Builder);
+ unsigned getTotalElements(Type *ArrayTy);
+ Type *getBaseElementType(Type *ArrayTy);
+ void recursivelyCollectGEPs(
+ GetElementPtrInst &CurrGEP, GetElementPtrInst &NewGEP,
+ GEPChainInfo &GEPChain,
+ SmallVector<ConstantInt *> Indices = SmallVector<ConstantInt *>(),
+ SmallVector<uint64_t> Dims = SmallVector<uint64_t>());
+ bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+};
+
+bool DXILFlattenArraysVisitor::finish() {
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+ return true;
+}
+
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+ return isa<ArrayType>(ArrType->getElementType());
+ return false;
+}
+
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+ unsigned TotalElements = 1;
+ Type *CurrArrayTy = ArrayTy;
+ while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+ TotalElements *= InnerArrayTy->getNumElements();
+ CurrArrayTy = InnerArrayTy->getElementType();
+ }
+ return TotalElements;
+}
+
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+ Type *CurrArrayTy = ArrayTy;
+ while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+ CurrArrayTy = InnerArrayTy->getElementType();
+ }
+ return CurrArrayTy;
+}
+
+ConstantInt *
+DXILFlattenArraysVisitor::flattenIndices(ArrayRef<ConstantInt *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder) {
+ assert(Indices.size() == Dims.size() &&
+ "Indicies and dimmensions should be the same");
+ unsigned FlatIndex = 0;
+ unsigned Multiplier = 1;
+
+ for (int I = Indices.size() - 1; I >= 0; --I) {
+ unsigned DimSize = Dims[I];
+ FlatIndex += Indices[I]->getZExtValue() * Multiplier;
+ Multiplier *= DimSize;
+ }
+ return Builder.getInt32(FlatIndex);
+}
+
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+ if (!isMultiDimensionalArray(AI.getAllocatedType()))
+ return false;
+
+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+ IRBuilder<> Builder(&AI);
+ unsigned TotalElements = getTotalElements(ArrType);
+
+ ArrayType *FattenedArrayType =
+ ArrayType::get(getBaseElementType(ArrType), TotalElements);
+ AllocaInst *FlatAlloca =
+ Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
+ FlatAlloca->setAlignment(AI.getAlign());
+ AI.replaceAllUsesWith(FlatAlloca);
+ AI.eraseFromParent();
+ return true;
+}
+
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+ GetElementPtrInst &CurrGEP, GetElementPtrInst &NewGEP,
+ GEPChainInfo &GEPChain, SmallVector<ConstantInt *> Indices,
+ SmallVector<uint64_t> Dims) {
+ ConstantInt *LastIndex =
+ cast<ConstantInt>(CurrGEP.getOperand(CurrGEP.getNumOperands() - 1));
+
+ Indices.push_back(LastIndex);
+ assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
+ Dims.push_back(
+ cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
+ if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
+ GEPChain.GEPToIndicesMap.insert({&CurrGEP, Indices});
+ GEPChain.GEPChildToNewParentMap.insert({&CurrGEP, &NewGEP});
+ GEPChain.GEPToDimmsMap.insert({&CurrGEP, Dims});
+ }
+ for (auto *User : CurrGEP.users()) {
+ if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
+ recursivelyCollectGEPs(*NestedGEP, NewGEP, GEPChain, Indices, Dims);
+ }
+ }
+ PotentiallyDeadInstrs.emplace_back(&CurrGEP);
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+ GetElementPtrInst &GEP) {
+ IRBuilder<> Builder(&GEP);
+ SmallVector<ConstantInt *> Indices = GEPChain.GEPToIndicesMap.at(&GEP);
+ GetElementPtrInst *Parent = GEPChain.GEPChildToNewParentMap.at(&GEP);
+ SmallVector<uint64_t> Dims = GEPChain.GEPToDimmsMap.at(&GEP);
+ ConstantInt *FlatIndex = flattenIndices(Indices, Dims, Builder);
+ if (!FlatIndex->isZero()) {
+ ArrayType *FlattenedArrayType =
+ cast<ArrayType>(Parent->getSourceElementType());
+ Value *FlatGEP =
+ Builder.CreateGEP(FlattenedArrayType, Parent->getPointerOperand(),
+ FlatIndex, GEP.getName() + ".flat", GEP.isInBounds());
+
+ GEP.replaceAllUsesWith(FlatGEP);
+ }
+ GEP.eraseFromParent();
+ return true;
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+ auto It = GEPChain.GEPToIndicesMap.find(&GEP);
+ if (It != GEPChain.GEPToIndicesMap.end())
+ return visitGetElementPtrInstInGEPChain(GEP);
+ if (!isMultiDimensionalArray(GEP.getSourceElementType()))
+ return false;
+
+ ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
+ IRBuilder<> Builder(&GEP);
+ unsigned TotalElements = getTotalElements(ArrType);
+ ArrayType *FattenedArrayType =
+ ArrayType::get(getBaseElementType(ArrType), TotalElements);
+
+ ConstantInt *FlatIndex = Builder.getInt32(0);
+ Value *PtrOperand = GEP.getPointerOperand();
+ auto OpExists = GEPChain.OperandToBaseGEPMap.find(PtrOperand);
+
+ GetElementPtrInst *FlatGEP = nullptr;
+ if (OpExists == GEPChain.OperandToBaseGEPMap.end()) {
+ FlatGEP = cast<GetElementPtrInst>(
+ Builder.CreateGEP(FattenedArrayType, PtrOperand, FlatIndex,
+ GEP.getName() + ".flat", GEP.isInBounds()));
+ GEPChain.OperandToBaseGEPMap.insert({PtrOperand, FlatGEP});
+ } else
+ FlatGEP = OpExists->getSecond();
+ recursivelyCollectGEPs(GEP, *FlatGEP, GEPChain);
+ GEP.replaceAllUsesWith(FlatGEP);
+ GEP.eraseFromParent();
+ return true;
+}
+
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+ bool MadeChange = false;
+ ////for (BasicBlock &BB : make_early_inc_range(F)) {
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+ for (Instruction &I : make_early_inc_range(*BB)) {
+ if (InstVisitor::visit(I) && I.getType()->isVoidTy()) {
+ I.eraseFromParent();
+ MadeChange = true;
+ }
+ }
+ }
+ return MadeChange;
+}
+
+static bool flattenArrays(Module &M) {
+ // TODO
+ bool MadeChange = false;
+ DXILFlattenArraysVisitor Impl;
+ for (auto &F : make_early_inc_range(M.functions())) {
+ MadeChange = Impl.visit(F);
+ }
+ return MadeChange;
+}
+
+PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
+ bool MadeChanges = flattenArrays(M);
+ if (!MadeChanges)
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserve<DXILResourceAnalysis>();
+ return PA;
+}
+
+bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
+ return flattenArrays(M);
+}
+
+void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addPreserved<DXILResourceWrapperPass>();
+}
+
+char DXILFlattenArraysLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
+ "DXIL Array Flattener", false, false)
+INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
+ false, false)
+
+ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
+ return new DXILFlattenArraysLegacy();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
new file mode 100644
index 00000000000000..409f8d198782c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -0,0 +1,25 @@
+//===- DXILFlattenArrays.h - Perform flattening of DXIL Arrays -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms multidimensional arrays into one-dimensional arrays.
+class DXILFlattenArrays : public PassInfoMixin<DXILFlattenArrays> {
+public:
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3221779be2f311..3454f16ecd5955 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -40,6 +40,12 @@ void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
/// Pass to scalarize llvm global data into a DXIL legal form
ModulePass *createDXILDataScalarizationLegacyPass();
+/// Initializer for DXIL Array Flatten Pass
+void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
+
+/// Pass to flatten arrays into a one dimensional DXIL legal form
+ModulePass *createDXILFlattenArraysLegacyPass();
+
/// Initializer for DXILOpLowering
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 18251ea3bd01d3..4e06eb8b713f59 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -13,6 +13,7 @@
#include "DirectXTargetMachine.h"
#include "DXILDataScalarization.h"
+#include "DXILFlattenArrays.h"
#include "DXILIntrinsicExpansion.h"
#include "DXILOpLowering.h"
#include "DXILPrettyPrinter.h"
@@ -48,6 +49,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
auto *PR = PassRegistry::getPassRegistry();
initializeDXILIntrinsicExpansionLegacyPass(*PR);
initializeDXILDataScalarizationLegacyPass(*PR);
+ initializeDXILFlattenArraysLegacyPass(*PR);
initializeScalarizerLegacyPassPass(*PR);
initializeDXILPrepareModulePass(*PR);
initializeEmbedDXILPassPass(*PR);
@@ -92,6 +94,7 @@ class DirectXPassConfig : public TargetPassConfig {
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createScalarizerPass(DxilScalarOptions));
+ addPass(createDXILFlattenArraysLegacyPass());
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
new file mode 100644
index 00000000000000..1b2d15428cb4a9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -0,0 +1,74 @@
+; RUN: opt -S -dxil-flatten-arrays %s | FileCheck %s
+
+; CHECK-LABEL: alloca_2d_test
+define void @alloca_2d_test () {
+ ; CHECK: alloca [9 x i32], align 4
+ ; CHECK-NOT: alloca [3 x [3 x i32]], align 4
+ %1 = alloca [3 x [3 x i32]], align 4
+ ret void
+}
+
+; CHECK-LABEL: alloca_3d_test
+define void @alloca_3d_test () {
+ ; CHECK: alloca [8 x i32], align 4
+ ; CHECK-NOT: alloca [2 x[2 x [2 x i32]]], align 4
+ %1 = alloca [2 x[2 x [2 x i32]]], align 4
+ ret void
+}
+
+; CHECK-LABEL: alloca_4d_test
+define void @alloca_4d_test () {
+ ; CHECK: alloca [16 x i32], align 4
+ ; CHECK-NOT: alloca [ 2x[2 x[2 x [2 x i32]]]], align 4
+ %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+ ret void
+}
+
+; CHECK-LABEL: gep_2d_test
+define void @gep_2d_test () {
+ ; CHECK: [[a:%.*]] = alloca [9 x i32], align 4
+ ; CHECK-COUNT-9: getelementptr inbounds [9 x i32], ptr [[a]], i32 {{[0-8]}}
+ ; CHECK-NOT: getelementptr inbounds [3 x [3 x i32]], ptr %1, i32 0, i32 {{.*}}, i32 {{.*}}
+ ; CHECK-NOT: getelementptr inbounds [3 x i32], [3 x i32]* {{.*}}, i32 {{.*}}, i32 {{.*}}
+ %1 = alloca [3 x [3 x i32]], align 4
+ %g2d0 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 0
+ %g1d0_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 0
+ %g1d0_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 1
+ %g1d0_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 2
+ %g2d1 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 1
+ %g1d1_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 0
+ %g1d1_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 1
+ %g1d1_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 2
+ %g2d2 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 2
+ %g1d2_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 0
+ %g1d2_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 1
+ %g1d2_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d2, i32 0, i32 2
+
+ ret void
+}
+
+; CHECK-LABEL: gep_3d_test
+define void @gep_3d_test () {
+ %1 = alloca [2 x[2 x [2 x i32]]], align 4
+ %g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 0
+ %g2d0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
+ %g1d0_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 0
+ %g1d0_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 1
+ %g2d1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 1
+ %g1d1_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 0
+ %g1d1_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 1
+ %g3d1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 1
+ %g2d2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 0
+ %g1d2_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d2, i32 0, i32 0
+ %g1d2_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d2, i32 0, i32 1
+ %g2d3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 1
+ %g1d3_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d3, i32 0, i32 0
+ %g1d3_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d3, i32 0, i32 1
+ ret void
+}
+
+; CHECK-LABEL: gep_4d_test
+define void @gep_4d_test () {
+ %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+ ret void
+}
\ No newline at end of file
>From 9cad0949ce67fc39b5b4473dc3c7bec984d4e6dd Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 10 Oct 2024 14:12:12 -0400
Subject: [PATCH 2/7] Investigate how to remove a non-leaf GEP in the chain via
 tracking it
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 78 +++++++-----------
llvm/test/CodeGen/DirectX/flatten-array.ll | 81 +++++++++++++++++--
2 files changed, 104 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 126a27b3902e7c..dfa6cd221f1733 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -23,6 +23,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
+#include <cstddef>
#include <cstdint>
#define DEBUG_TYPE "dxil-flatten-arrays"
@@ -39,11 +40,11 @@ class DXILFlattenArraysLegacy : public ModulePass {
static char ID; // Pass identification.
};
-struct GEPChainInfo {
- DenseMap<GetElementPtrInst *, SmallVector<ConstantInt *>> GEPToIndicesMap;
- DenseMap<GetElementPtrInst *, SmallVector<uint64_t>> GEPToDimmsMap;
- DenseMap<GetElementPtrInst *, GetElementPtrInst *> GEPChildToNewParentMap;
- DenseMap<Value *, GetElementPtrInst *> OperandToBaseGEPMap;
+struct GEPData {
+ ArrayType* ParentArrayType;
+ Value *ParendOperand;
+ SmallVector<ConstantInt *> Indices;
+ SmallVector<uint64_t> Dims;
};
class DXILFlattenArraysVisitor
@@ -74,7 +75,7 @@ class DXILFlattenArraysVisitor
private:
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
- GEPChainInfo GEPChain;
+ DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
bool finish();
bool isMultiDimensionalArray(Type *T);
ConstantInt *flattenIndices(ArrayRef<ConstantInt *> Indices,
@@ -82,11 +83,11 @@ class DXILFlattenArraysVisitor
unsigned getTotalElements(Type *ArrayTy);
Type *getBaseElementType(Type *ArrayTy);
void recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, GetElementPtrInst &NewGEP,
- GEPChainInfo &GEPChain,
+ GetElementPtrInst &CurrGEP, ArrayType *FattenedArrayType, Value *PtrOperand,
SmallVector<ConstantInt *> Indices = SmallVector<ConstantInt *>(),
SmallVector<uint64_t> Dims = SmallVector<uint64_t>());
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+ void recursivelyDeleteGEPs(GetElementPtrInst *CurrGEP = nullptr);
};
bool DXILFlattenArraysVisitor::finish() {
@@ -154,9 +155,8 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
}
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, GetElementPtrInst &NewGEP,
- GEPChainInfo &GEPChain, SmallVector<ConstantInt *> Indices,
- SmallVector<uint64_t> Dims) {
+ GetElementPtrInst &CurrGEP, ArrayType *FattenedArrayType, Value *PtrOperand,
+ SmallVector<ConstantInt *> Indices, SmallVector<uint64_t> Dims) {
ConstantInt *LastIndex =
cast<ConstantInt>(CurrGEP.getOperand(CurrGEP.getNumOperands() - 1));
@@ -165,41 +165,36 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
Dims.push_back(
cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
- GEPChain.GEPToIndicesMap.insert({&CurrGEP, Indices});
- GEPChain.GEPChildToNewParentMap.insert({&CurrGEP, &NewGEP});
- GEPChain.GEPToDimmsMap.insert({&CurrGEP, Dims});
+ GEPChainMap.insert(
+ {&CurrGEP, {std::move(FattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims)}});
}
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
- recursivelyCollectGEPs(*NestedGEP, NewGEP, GEPChain, Indices, Dims);
+ recursivelyCollectGEPs(*NestedGEP, FattenedArrayType, PtrOperand, Indices, Dims);
}
}
- PotentiallyDeadInstrs.emplace_back(&CurrGEP);
}
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
GetElementPtrInst &GEP) {
IRBuilder<> Builder(&GEP);
- SmallVector<ConstantInt *> Indices = GEPChain.GEPToIndicesMap.at(&GEP);
- GetElementPtrInst *Parent = GEPChain.GEPChildToNewParentMap.at(&GEP);
- SmallVector<uint64_t> Dims = GEPChain.GEPToDimmsMap.at(&GEP);
- ConstantInt *FlatIndex = flattenIndices(Indices, Dims, Builder);
- if (!FlatIndex->isZero()) {
- ArrayType *FlattenedArrayType =
- cast<ArrayType>(Parent->getSourceElementType());
- Value *FlatGEP =
- Builder.CreateGEP(FlattenedArrayType, Parent->getPointerOperand(),
- FlatIndex, GEP.getName() + ".flat", GEP.isInBounds());
-
- GEP.replaceAllUsesWith(FlatGEP);
- }
+ GEPData GEPInfo = GEPChainMap.at(&GEP);
+ ConstantInt *FlatIndex =
+ flattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+ ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
+ Value *FlatGEP = Builder.CreateGEP(
+ FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
+ GEP.getName() + ".flat", GEP.isInBounds());
+
+ GEP.replaceAllUsesWith(FlatGEP);
GEP.eraseFromParent();
return true;
}
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
- auto It = GEPChain.GEPToIndicesMap.find(&GEP);
- if (It != GEPChain.GEPToIndicesMap.end())
+ auto It = GEPChainMap.find(&GEP);
+ if (It != GEPChainMap.end())
return visitGetElementPtrInstInGEPChain(GEP);
if (!isMultiDimensionalArray(GEP.getSourceElementType()))
return false;
@@ -210,21 +205,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
ArrayType *FattenedArrayType =
ArrayType::get(getBaseElementType(ArrType), TotalElements);
- ConstantInt *FlatIndex = Builder.getInt32(0);
Value *PtrOperand = GEP.getPointerOperand();
- auto OpExists = GEPChain.OperandToBaseGEPMap.find(PtrOperand);
-
- GetElementPtrInst *FlatGEP = nullptr;
- if (OpExists == GEPChain.OperandToBaseGEPMap.end()) {
- FlatGEP = cast<GetElementPtrInst>(
- Builder.CreateGEP(FattenedArrayType, PtrOperand, FlatIndex,
- GEP.getName() + ".flat", GEP.isInBounds()));
- GEPChain.OperandToBaseGEPMap.insert({PtrOperand, FlatGEP});
- } else
- FlatGEP = OpExists->getSecond();
- recursivelyCollectGEPs(GEP, *FlatGEP, GEPChain);
- GEP.replaceAllUsesWith(FlatGEP);
- GEP.eraseFromParent();
+
+ recursivelyCollectGEPs(GEP, FattenedArrayType, PtrOperand);
+ PotentiallyDeadInstrs.emplace_back(&GEP);
return true;
}
@@ -240,11 +224,11 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
}
}
}
+ finish();
return MadeChange;
}
static bool flattenArrays(Module &M) {
- // TODO
bool MadeChange = false;
DXILFlattenArraysVisitor Impl;
for (auto &F : make_early_inc_range(M.functions())) {
@@ -279,4 +263,4 @@ INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
return new DXILFlattenArraysLegacy();
-}
\ No newline at end of file
+}
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index 1b2d15428cb4a9..09cfd63466d079 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-flatten-arrays %s | FileCheck %s
+; RUN: opt -S -dxil-flatten-arrays %s | FileCheck %s
; CHECK-LABEL: alloca_2d_test
define void @alloca_2d_test () {
@@ -28,13 +28,13 @@ define void @alloca_4d_test () {
define void @gep_2d_test () {
; CHECK: [[a:%.*]] = alloca [9 x i32], align 4
; CHECK-COUNT-9: getelementptr inbounds [9 x i32], ptr [[a]], i32 {{[0-8]}}
- ; CHECK-NOT: getelementptr inbounds [3 x [3 x i32]], ptr %1, i32 0, i32 {{.*}}, i32 {{.*}}
- ; CHECK-NOT: getelementptr inbounds [3 x i32], [3 x i32]* {{.*}}, i32 {{.*}}, i32 {{.*}}
+ ; CHECK-NOT: getelementptr inbounds [3 x [3 x i32]], ptr %1, i32 0, i32 0, i32 {{[0-2]}}
+ ; CHECK-NOT: getelementptr inbounds [3 x i32], [3 x i32]* {{.*}}, i32 0, i32 {{[0-2]}}
%1 = alloca [3 x [3 x i32]], align 4
%g2d0 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 0
- %g1d0_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 0
- %g1d0_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 1
- %g1d0_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 2
+ %g1d_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 0
+ %g1d_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 1
+ %g1d_3 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 2
%g2d1 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 1
%g1d1_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 0
%g1d1_2 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d1, i32 0, i32 1
@@ -49,11 +49,16 @@ define void @gep_2d_test () {
; CHECK-LABEL: gep_3d_test
define void @gep_3d_test () {
+ ; CHECK: [[a:%.*]] = alloca [8 x i32], align 4
+ ; CHECK-COUNT-8: getelementptr inbounds [8 x i32], ptr [[a]], i32 {{[0-7]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x[2 x [2 x i32]]], ptr %1, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x i32], [2 x i32]* {{.*}}, i32 0, i32 {{[0-1]}}
%1 = alloca [2 x[2 x [2 x i32]]], align 4
%g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 0
%g2d0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
- %g1d0_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 0
- %g1d0_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 1
+ %g1d_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 0
+ %g1d_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0, i32 0, i32 1
%g2d1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 1
%g1d1_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 0
%g1d1_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1, i32 0, i32 1
@@ -69,6 +74,66 @@ define void @gep_3d_test () {
; CHECK-LABEL: gep_4d_test
define void @gep_4d_test () {
+ ; CHECK: [[a:%.*]] = alloca [16 x i32], align 4
+ ; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[a]], i32 {{[0-9]|1[0-5]}}
+ ; CHECK-NOT: getelementptr inbounds [2x[2 x[2 x [2 x i32]]]],, ptr %1, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x[2 x [2 x i32]]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2 x i32], [2 x i32]* {{.*}}, i32 0, i32 {{[0-1]}}
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+ %g4d0 = getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], [2x[2 x[2 x [2 x i32]]]]* %1, i32 0, i32 0
+ %g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d0, i32 0, i32 0
+ %g2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
+ %g1d_0 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_0, i32 0, i32 0
+ %g1d_1 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_0, i32 0, i32 1
+ %g2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 1
+ %g1d_2 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_1, i32 0, i32 0
+ %g1d_3 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_1, i32 0, i32 1
+ %g3d1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d0, i32 0, i32 1
+ %g2d0_2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 0
+ %g1d_4 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_2, i32 0, i32 0
+ %g1d_5 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_2, i32 0, i32 1
+ %g2d1_2 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1, i32 0, i32 1
+ %g1d_6 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_2, i32 0, i32 0
+ %g1d_7 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_2, i32 0, i32 1
+ %g4d1 = getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], [2x[2 x[2 x [2 x i32]]]]* %1, i32 0, i32 1
+ %g3d0_1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d1, i32 0, i32 0
+ %g2d0_3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0_1, i32 0, i32 0
+ %g1d_8 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_3, i32 0, i32 0
+ %g1d_9 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_3, i32 0, i32 1
+ %g2d0_4 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0_1, i32 0, i32 1
+ %g1d_10 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_4, i32 0, i32 0
+ %g1d_11 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_4, i32 0, i32 1
+ %g3d1_1 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d1, i32 0, i32 1
+ %g2d0_5 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1_1, i32 0, i32 0
+ %g1d_12 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_5, i32 0, i32 0
+ %g1d_13 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d0_5, i32 0, i32 1
+ %g2d1_3 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d1_1, i32 0, i32 1
+ %g1d_14 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_3, i32 0, i32 0
+ %g1d_15 = getelementptr inbounds [2 x i32], [2 x i32]* %g2d1_3, i32 0, i32 1
+ ret void
+}
+
+; CHECK-LABEL: bitcast_2d_test
+define void @bitcast_2d_test () {
+ ; CHECK: alloca [9 x i32], align 4
+ %1 = alloca [3 x [3 x i32]], align 4
+ bitcast [3 x [3 x i32]]* %1 to i8*
+ ret void
+}
+
+; CHECK-LABEL: bitcast_3d_test
+define void @bitcast_3d_test () {
+ ; CHECK: alloca [8 x i32], align 4
+ %1 = alloca [2 x[2 x [2 x i32]]], align 4
+ bitcast [2 x[2 x [2 x i32]]]* %1 to i8*
+ ret void
+}
+
+; CHECK-LABEL: bitcast_4d_test
+define void @bitcast_4d_test () {
+ ; CHECK: alloca [16 x i32], align 4
+ %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
+ bitcast [2x[2 x[2 x [2 x i32]]]]* %1 to i8*
ret void
}
\ No newline at end of file
>From f6797ebf9c78976a232aa22695c28e4bc9519c3c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 16 Oct 2024 18:56:26 -0400
Subject: [PATCH 3/7] fix up test cases. still need to resolve the ConstantExpr
case.
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 175 ++++++++++++++++--
.../Target/DirectX/DirectXPassRegistry.def | 1 +
.../Target/DirectX/DirectXTargetMachine.cpp | 2 +-
llvm/test/CodeGen/DirectX/flatten-array.ll | 50 +++--
llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 +
llvm/test/CodeGen/DirectX/scalar-data.ll | 11 +-
llvm/test/CodeGen/DirectX/scalar-load.ll | 23 ++-
llvm/test/CodeGen/DirectX/scalar-store.ll | 11 +-
8 files changed, 218 insertions(+), 56 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index dfa6cd221f1733..443c08610a508a 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -41,7 +41,7 @@ class DXILFlattenArraysLegacy : public ModulePass {
};
struct GEPData {
- ArrayType* ParentArrayType;
+ ArrayType *ParentArrayType;
Value *ParendOperand;
SmallVector<ConstantInt *> Indices;
SmallVector<uint64_t> Dims;
@@ -72,20 +72,22 @@ class DXILFlattenArraysVisitor
bool visitStoreInst(StoreInst &SI) { return false; }
bool visitCallInst(CallInst &ICI) { return false; }
bool visitFreezeInst(FreezeInst &FI) { return false; }
+ static bool isMultiDimensionalArray(Type *T);
+ static unsigned getTotalElements(Type *ArrayTy);
+ static Type *getBaseElementType(Type *ArrayTy);
private:
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
bool finish();
- bool isMultiDimensionalArray(Type *T);
ConstantInt *flattenIndices(ArrayRef<ConstantInt *> Indices,
ArrayRef<uint64_t> Dims, IRBuilder<> &Builder);
- unsigned getTotalElements(Type *ArrayTy);
- Type *getBaseElementType(Type *ArrayTy);
void recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, ArrayType *FattenedArrayType, Value *PtrOperand,
+ GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+ Value *PtrOperand,
SmallVector<ConstantInt *> Indices = SmallVector<ConstantInt *>(),
SmallVector<uint64_t> Dims = SmallVector<uint64_t>());
+ ConstantInt *computeFlatIndex(GetElementPtrInst& GEP);
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
void recursivelyDeleteGEPs(GetElementPtrInst *CurrGEP = nullptr);
};
@@ -154,23 +156,58 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
return true;
}
+ConstantInt *DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst& GEP) {
+ unsigned IndexAmount = GEP.getNumIndices();
+ assert(IndexAmount >= 1 && "Need At least one Index");
+ if(IndexAmount == 1)
+ return dyn_cast<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1));
+
+ // Get the type of the base pointer.
+ Type *BaseType = GEP.getSourceElementType();
+
+ // Determine the dimensions of the multi-dimensional array.
+ SmallVector<int64_t> Dimensions;
+ while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
+ Dimensions.push_back(ArrType->getNumElements());
+ BaseType = ArrType->getElementType();
+ }
+ unsigned FlatIndex = 0;
+ unsigned Multiplier = 1;
+ unsigned BitWidth = 32;
+ for (const Use &Index : GEP.indices()) {
+ ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
+ BitWidth = CurrentIndex->getBitWidth();
+ if (!CurrentIndex)
+ return nullptr;
+ int64_t IndexValue = CurrentIndex->getSExtValue();
+ FlatIndex += IndexValue * Multiplier;
+
+ if (!Dimensions.empty()) {
+ Multiplier *= Dimensions.back(); // Use the last dimension size
+ Dimensions.pop_back(); // Remove the last dimension
+ }
+ }
+ return ConstantInt::get(GEP.getContext(), APInt(BitWidth, FlatIndex));
+}
+
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, ArrayType *FattenedArrayType, Value *PtrOperand,
+ GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand,
SmallVector<ConstantInt *> Indices, SmallVector<uint64_t> Dims) {
- ConstantInt *LastIndex =
- cast<ConstantInt>(CurrGEP.getOperand(CurrGEP.getNumOperands() - 1));
-
+ ConstantInt *LastIndex = dyn_cast<ConstantInt>(CurrGEP.getOperand(CurrGEP .getNumOperands() - 1));
+ assert(LastIndex && "Flattening a GEP chain only work on constant indicies");
Indices.push_back(LastIndex);
assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
Dims.push_back(
cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
- GEPChainMap.insert(
- {&CurrGEP, {std::move(FattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims)}});
+ GEPChainMap.insert({&CurrGEP,
+ {std::move(FlattenedArrayType), PtrOperand,
+ std::move(Indices), std::move(Dims)}});
}
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
- recursivelyCollectGEPs(*NestedGEP, FattenedArrayType, PtrOperand, Indices, Dims);
+ recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, Indices,
+ Dims);
}
}
}
@@ -181,11 +218,11 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
GEPData GEPInfo = GEPChainMap.at(&GEP);
ConstantInt *FlatIndex =
flattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-
+
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
- Value *FlatGEP = Builder.CreateGEP(
- FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
- GEP.getName() + ".flat", GEP.isInBounds());
+ Value *FlatGEP =
+ Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
+ GEP.getName() + ".flat", GEP.isInBounds());
GEP.replaceAllUsesWith(FlatGEP);
GEP.eraseFromParent();
@@ -202,13 +239,22 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
IRBuilder<> Builder(&GEP);
unsigned TotalElements = getTotalElements(ArrType);
- ArrayType *FattenedArrayType =
+ ArrayType *FlattenedArrayType =
ArrayType::get(getBaseElementType(ArrType), TotalElements);
Value *PtrOperand = GEP.getPointerOperand();
+ if(isa<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1))) {
+ recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand);
+ PotentiallyDeadInstrs.emplace_back(&GEP);
+ } else {
+ SmallVector<Value *> Indices(GEP.idx_begin(),GEP.idx_end());
+ Value *FlatGEP =
+ Builder.CreateGEP(FlattenedArrayType, PtrOperand, Indices,
+ GEP.getName() + ".flat", GEP.isInBounds());
- recursivelyCollectGEPs(GEP, FattenedArrayType, PtrOperand);
- PotentiallyDeadInstrs.emplace_back(&GEP);
+ GEP.replaceAllUsesWith(FlatGEP);
+ GEP.eraseFromParent();
+ }
return true;
}
@@ -228,11 +274,100 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
return MadeChange;
}
+static void collectElements(Constant *Init, SmallVectorImpl<Constant *> &Elements) {
+ // Base case: If Init is not an array, add it directly to the vector.
+ if (!isa<ArrayType>(Init->getType())) {
+ Elements.push_back(Init);
+ return;
+ }
+
+ // Recursive case: Process each element in the array.
+ if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
+ for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
+ collectElements(ArrayConstant->getOperand(I), Elements);
+ }
+ } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
+ for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
+ collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
+ }
+ } else {
+ assert( false && "Expected a ConstantArray or ConstantDataArray for array initializer!");
+ }
+}
+
+static Constant *transformInitializer(Constant *Init, Type *OrigType,
+ ArrayType *FlattenedType,
+ LLVMContext &Ctx) {
+ // Handle ConstantAggregateZero (zero-initialized constants)
+ if (isa<ConstantAggregateZero>(Init))
+ return ConstantAggregateZero::get(FlattenedType);
+
+ // Handle UndefValue (undefined constants)
+ if (isa<UndefValue>(Init))
+ return UndefValue::get(FlattenedType);
+
+ if (!isa<ArrayType>(OrigType))
+ return Init;
+
+ SmallVector<Constant *> FlattenedElements;
+ collectElements(Init, FlattenedElements);
+ assert(FlattenedType->getNumElements() == FlattenedElements.size() && "The number of collected elements should match the FlattenedType");
+ return ConstantArray::get(FlattenedType, FlattenedElements);
+}
+
+static void
+flattenGlobalArrays(Module &M,
+ DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+ LLVMContext &Ctx = M.getContext();
+ for (GlobalVariable &G : M.globals()) {
+ Type *OrigType = G.getValueType();
+ if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
+ continue;
+
+ ArrayType *ArrType = cast<ArrayType>(OrigType);
+ unsigned TotalElements =
+ DXILFlattenArraysVisitor::getTotalElements(ArrType);
+ ArrayType *FattenedArrayType = ArrayType::get(
+ DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+
+ // Create a new global variable with the updated type
+ // Note: Initializer is set via transformInitializer
+ GlobalVariable *NewGlobal =
+ new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
+ /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
+ G.getThreadLocalMode(), G.getAddressSpace(),
+ G.isExternallyInitialized());
+
+ // Copy relevant attributes
+ NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
+ if (G.getAlignment() > 0) {
+ NewGlobal->setAlignment(G.getAlign());
+ }
+
+ if (G.hasInitializer()) {
+ Constant *Init = G.getInitializer();
+ Constant *NewInit =
+ transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
+ NewGlobal->setInitializer(NewInit);
+ }
+ GlobalMap[&G] = NewGlobal;
+ }
+}
+
static bool flattenArrays(Module &M) {
bool MadeChange = false;
DXILFlattenArraysVisitor Impl;
+ DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ flattenGlobalArrays(M, GlobalMap);
for (auto &F : make_early_inc_range(M.functions())) {
- MadeChange = Impl.visit(F);
+ if(F.isIntrinsic())
+ continue;
+ MadeChange |= Impl.visit(F);
+ }
+ for (auto &[Old, New] : GlobalMap) {
+ Old->replaceAllUsesWith(New);
+ Old->eraseFromParent();
+ MadeChange |= true;
}
return MadeChange;
}
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index ae729a1082b867..a0f864ed39375f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -24,6 +24,7 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
#define MODULE_PASS(NAME, CREATE_PASS)
#endif
MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
+MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion())
MODULE_PASS("dxil-op-lower", DXILOpLowering())
MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs()))
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 4e06eb8b713f59..59dbf053d6c222 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -93,8 +93,8 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createDXILDataScalarizationLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
- addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILFlattenArraysLegacyPass());
+ addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index 09cfd63466d079..dcda200c8aadf2 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -76,7 +76,7 @@ define void @gep_3d_test () {
define void @gep_4d_test () {
; CHECK: [[a:%.*]] = alloca [16 x i32], align 4
; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[a]], i32 {{[0-9]|1[0-5]}}
- ; CHECK-NOT: getelementptr inbounds [2x[2 x[2 x [2 x i32]]]],, ptr %1, i32 0, i32 0, i32 {{[0-1]}}
+ ; CHECK-NOT: getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], ptr %1, i32 0, i32 0, i32 {{[0-1]}}
; CHECK-NOT: getelementptr inbounds [2 x[2 x [2 x i32]]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
; CHECK-NOT: getelementptr inbounds [2 x i32], [2 x i32]* {{.*}}, i32 0, i32 {{[0-1]}}
@@ -114,26 +114,36 @@ define void @gep_4d_test () {
ret void
}
-; CHECK-LABEL: bitcast_2d_test
-define void @bitcast_2d_test () {
- ; CHECK: alloca [9 x i32], align 4
- %1 = alloca [3 x [3 x i32]], align 4
- bitcast [3 x [3 x i32]]* %1 to i8*
- ret void
-}
-; CHECK-LABEL: bitcast_3d_test
-define void @bitcast_3d_test () {
- ; CHECK: alloca [8 x i32], align 4
- %1 = alloca [2 x[2 x [2 x i32]]], align 4
- bitcast [2 x[2 x [2 x i32]]]* %1 to i8*
- ret void
+@a = internal global [2 x [3 x [4 x i32]]] [[3 x [4 x i32]] [[4 x i32] [i32 0, i32 1, i32 2, i32 3],
+ [4 x i32] [i32 4, i32 5, i32 6, i32 7],
+ [4 x i32] [i32 8, i32 9, i32 10, i32 11]],
+ [3 x [4 x i32]] [[4 x i32] [i32 12, i32 13, i32 14, i32 15],
+ [4 x i32] [i32 16, i32 17, i32 18, i32 19],
+ [4 x i32] [i32 20, i32 21, i32 22, i32 23]]], align 4
+
+@b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16
+
+define void @global_gep_load() {
+ ; CHECK: load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 6), align 4
+ ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+ ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+ ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+ %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 0
+ %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 1
+ %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 2
+ %4 = load i32, i32* %3, align 4
+ ret void
}
-; CHECK-LABEL: bitcast_4d_test
-define void @bitcast_4d_test () {
- ; CHECK: alloca [16 x i32], align 4
- %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
- bitcast [2x[2 x[2 x [2 x i32]]]]* %1 to i8*
- ret void
+define void @global_gep_store() {
+ ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 13), align 4
+ ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+ ; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+ ; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+ %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @b, i32 0, i32 1
+ %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 0
+ %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 1
+ store i32 1, i32* %3, align 4
+ ret void
}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 224037cfe7fbe3..f0950df08eff5b 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -9,6 +9,7 @@
; CHECK-NEXT: ModulePass Manager
; CHECK-NEXT: DXIL Intrinsic Expansion
; CHECK-NEXT: DXIL Data Scalarization
+; CHECK-NEXT: DXIL Array Flattener
; CHECK-NEXT: FunctionPass Manager
; CHECK-NEXT: Dominator Tree Construction
; CHECK-NEXT: Scalarize vector operations
diff --git a/llvm/test/CodeGen/DirectX/scalar-data.ll b/llvm/test/CodeGen/DirectX/scalar-data.ll
index c436f1eae4425d..11b76205aa54c6 100644
--- a/llvm/test/CodeGen/DirectX/scalar-data.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-data.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
; Make sure we don't touch arrays without vectors and that can recurse multiple-dimension arrays of vectors
@@ -8,5 +9,9 @@
; CHECK @staticArray
; CHECK-NOT: @staticArray.scalarized
-; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
-; CHECK-NOT: @groushared3dArrayofVectors
+; CHECK-NOT: @staticArray.scalarized.1dim
+; CHECK-NOT: @staticArray.1dim
+; DATACHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
+; CHECK: @groushared3dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [108 x i32] zeroinitializer, align 16
+; DATACHECK-NOT: @groushared3dArrayofVectors
+; CHECK-NOT: @groushared3dArrayofVectors.scalarized
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index b911a8f7855bb8..28a453da51f597 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
; Make sure we can load groupshared, static vectors and arrays of vectors
@@ -8,20 +9,25 @@
@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
-; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; DATACHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
-; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
-; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
+; DATACHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+; CHECK: @staticArrayOfVecData.scalarized.1dim = internal global [12 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12], align 4
+; DATACHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
+; CHECK: @groushared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @arrayofVecData.scalarized
; CHECK-NOT: @vecData
; CHECK-NOT: @staticArrayOfVecData
+; CHECK-NOT: @staticArrayOfVecData.scalarized
; CHECK-NOT: @groushared2dArrayofVectors
-
+; CHECK-NOT: @groushared2dArrayofVectors.scalarized
; CHECK-LABEL: load_array_vec_test
define <4 x i32> @load_array_vec_test() #0 {
- ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
+ ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align 4
; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
%1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
%2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
@@ -39,7 +45,8 @@ define <4 x i32> @load_vec_test() #0 {
; CHECK-LABEL: load_static_array_of_vec_test
define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
- ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+ ; DATACHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+ ; CHECK: getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 %index
; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
; CHECK-NOT: load i32, ptr {{.*}}, align 4
%3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
@@ -49,7 +56,7 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
; CHECK-LABEL: multid_load_test
define <4 x i32> @multid_load_test() #0 {
- ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4
+ ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.1dim.*|%.*)}}, align 4
; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
%1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
%2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index c45481e8cae14f..90630ca4e92b84 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,4 +1,5 @@
-; RUN: opt -S -passes='dxil-data-scalarization,scalarizer<load-store>,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-data-scalarization,scalarizer<load-store>,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
; Make sure we can store groupshared, static vectors and arrays of vectors
@@ -6,15 +7,17 @@
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
@"vecData" = external addrspace(3) global <4 x i32>, align 4
-; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; DATACHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @arrayofVecData.scalarized
; CHECK-NOT: @vecData
; CHECK-LABEL: store_array_vec_test
define void @store_array_vec_test () local_unnamed_addr #0 {
- ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
- ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+ ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align {{4|8|16}}
+ ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.1dim.*|%.*)}}, align {{4|8|16}}
store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"arrayofVecData", align 16
store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3) getelementptr inbounds (i8, ptr addrspace(3) @"arrayofVecData", i32 16), align 16
ret void
>From 1dd694a0597b0ee588ef1dcb52df932ae97178d0 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 16 Oct 2024 20:38:58 -0400
Subject: [PATCH 4/7] TODO: find bug that prevents 1 dim for Value based index
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 150 +++++++++++-------
1 file changed, 90 insertions(+), 60 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 443c08610a508a..b33e075342b597 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -43,8 +43,9 @@ class DXILFlattenArraysLegacy : public ModulePass {
struct GEPData {
ArrayType *ParentArrayType;
Value *ParendOperand;
- SmallVector<ConstantInt *> Indices;
+ SmallVector<Value *> Indices;
SmallVector<uint64_t> Dims;
+ bool AllIndicesAreConstInt;
};
class DXILFlattenArraysVisitor
@@ -80,16 +81,20 @@ class DXILFlattenArraysVisitor
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
bool finish();
- ConstantInt *flattenIndices(ArrayRef<ConstantInt *> Indices,
- ArrayRef<uint64_t> Dims, IRBuilder<> &Builder);
- void recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
- Value *PtrOperand,
- SmallVector<ConstantInt *> Indices = SmallVector<ConstantInt *>(),
- SmallVector<uint64_t> Dims = SmallVector<uint64_t>());
- ConstantInt *computeFlatIndex(GetElementPtrInst& GEP);
+ ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ void
+ recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+ ArrayType *FlattenedArrayType, Value *PtrOperand,
+ SmallVector<Value *> Indices = SmallVector<Value *>(),
+ SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+ bool AllIndicesAreConstInt = true);
+ ConstantInt *computeFlatIndex(GetElementPtrInst &GEP);
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
- void recursivelyDeleteGEPs(GetElementPtrInst *CurrGEP = nullptr);
};
bool DXILFlattenArraysVisitor::finish() {
@@ -121,10 +126,8 @@ Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
return CurrArrayTy;
}
-ConstantInt *
-DXILFlattenArraysVisitor::flattenIndices(ArrayRef<ConstantInt *> Indices,
- ArrayRef<uint64_t> Dims,
- IRBuilder<> &Builder) {
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+ ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
assert(Indices.size() == Dims.size() &&
"Indicies and dimmensions should be the same");
unsigned FlatIndex = 0;
@@ -132,12 +135,29 @@ DXILFlattenArraysVisitor::flattenIndices(ArrayRef<ConstantInt *> Indices,
for (int I = Indices.size() - 1; I >= 0; --I) {
unsigned DimSize = Dims[I];
- FlatIndex += Indices[I]->getZExtValue() * Multiplier;
+ ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
+ assert(CIndex && "This function expects all indicies to be ConstantInt");
+ FlatIndex += CIndex->getZExtValue() * Multiplier;
Multiplier *= DimSize;
}
return Builder.getInt32(FlatIndex);
}
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+ ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+ Value *FlatIndex = Builder.getInt32(0);
+ unsigned Multiplier = 1;
+
+ for (unsigned I = 0; I < Indices.size(); ++I) {
+ unsigned DimSize = Dims[I];
+ Value *VDimSize = Builder.getInt32(DimSize * Multiplier);
+ Value *ScaledIndex = Builder.CreateMul(Indices[I], VDimSize);
+ FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
+ Multiplier *= DimSize;
+ }
+ return FlatIndex;
+}
+
bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
if (!isMultiDimensionalArray(AI.getAllocatedType()))
return false;
@@ -156,58 +176,61 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
return true;
}
-ConstantInt *DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst& GEP) {
+ConstantInt *
+DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
unsigned IndexAmount = GEP.getNumIndices();
assert(IndexAmount >= 1 && "Need At least one Index");
- if(IndexAmount == 1)
+ if (IndexAmount == 1)
return dyn_cast<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1));
// Get the type of the base pointer.
Type *BaseType = GEP.getSourceElementType();
- // Determine the dimensions of the multi-dimensional array.
- SmallVector<int64_t> Dimensions;
- while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
- Dimensions.push_back(ArrType->getNumElements());
- BaseType = ArrType->getElementType();
- }
- unsigned FlatIndex = 0;
- unsigned Multiplier = 1;
- unsigned BitWidth = 32;
- for (const Use &Index : GEP.indices()) {
- ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
- BitWidth = CurrentIndex->getBitWidth();
- if (!CurrentIndex)
- return nullptr;
- int64_t IndexValue = CurrentIndex->getSExtValue();
- FlatIndex += IndexValue * Multiplier;
-
- if (!Dimensions.empty()) {
- Multiplier *= Dimensions.back(); // Use the last dimension size
- Dimensions.pop_back(); // Remove the last dimension
- }
+ // Determine the dimensions of the multi-dimensional array.
+ SmallVector<int64_t> Dimensions;
+ while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
+ Dimensions.push_back(ArrType->getNumElements());
+ BaseType = ArrType->getElementType();
+ }
+ unsigned FlatIndex = 0;
+ unsigned Multiplier = 1;
+ unsigned BitWidth = 32;
+ for (const Use &Index : GEP.indices()) {
+ ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
+ BitWidth = CurrentIndex->getBitWidth();
+ if (!CurrentIndex)
+ return nullptr;
+ int64_t IndexValue = CurrentIndex->getSExtValue();
+ FlatIndex += IndexValue * Multiplier;
+
+ if (!Dimensions.empty()) {
+ Multiplier *= Dimensions.back(); // Use the last dimension size
+ Dimensions.pop_back(); // Remove the last dimension
}
- return ConstantInt::get(GEP.getContext(), APInt(BitWidth, FlatIndex));
+ }
+ return ConstantInt::get(GEP.getContext(), APInt(BitWidth, FlatIndex));
}
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand,
- SmallVector<ConstantInt *> Indices, SmallVector<uint64_t> Dims) {
- ConstantInt *LastIndex = dyn_cast<ConstantInt>(CurrGEP.getOperand(CurrGEP .getNumOperands() - 1));
- assert(LastIndex && "Flattening a GEP chain only work on constant indicies");
+ GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+ Value *PtrOperand, SmallVector<Value *> Indices, SmallVector<uint64_t> Dims,
+ bool AllIndicesAreConstInt) {
+ Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+ AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
Indices.push_back(LastIndex);
assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
Dims.push_back(
cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
- GEPChainMap.insert({&CurrGEP,
- {std::move(FlattenedArrayType), PtrOperand,
- std::move(Indices), std::move(Dims)}});
+ GEPChainMap.insert(
+ {&CurrGEP,
+ {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+ std::move(Dims), AllIndicesAreConstInt}});
}
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
- recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, Indices,
- Dims);
+ recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+ Indices, Dims, AllIndicesAreConstInt);
}
}
}
@@ -216,8 +239,12 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
GetElementPtrInst &GEP) {
IRBuilder<> Builder(&GEP);
GEPData GEPInfo = GEPChainMap.at(&GEP);
- ConstantInt *FlatIndex =
- flattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+ Value *FlatIndex;
+ if (GEPInfo.AllIndicesAreConstInt)
+ FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+ else
+ FlatIndex =
+ instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
Value *FlatGEP =
@@ -243,18 +270,17 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
ArrayType::get(getBaseElementType(ArrType), TotalElements);
Value *PtrOperand = GEP.getPointerOperand();
- if(isa<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1))) {
- recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand);
- PotentiallyDeadInstrs.emplace_back(&GEP);
- } else {
+ // if(isa<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1))) {
+ recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand);
+ PotentiallyDeadInstrs.emplace_back(&GEP);
+ /*} else {
SmallVector<Value *> Indices(GEP.idx_begin(),GEP.idx_end());
Value *FlatGEP =
Builder.CreateGEP(FlattenedArrayType, PtrOperand, Indices,
GEP.getName() + ".flat", GEP.isInBounds());
-
GEP.replaceAllUsesWith(FlatGEP);
GEP.eraseFromParent();
- }
+ }*/
return true;
}
@@ -274,7 +300,8 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
return MadeChange;
}
-static void collectElements(Constant *Init, SmallVectorImpl<Constant *> &Elements) {
+static void collectElements(Constant *Init,
+ SmallVectorImpl<Constant *> &Elements) {
// Base case: If Init is not an array, add it directly to the vector.
if (!isa<ArrayType>(Init->getType())) {
Elements.push_back(Init);
@@ -291,7 +318,9 @@ static void collectElements(Constant *Init, SmallVectorImpl<Constant *> &Element
collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
}
} else {
- assert( false && "Expected a ConstantArray or ConstantDataArray for array initializer!");
+ assert(
+ false &&
+ "Expected a ConstantArray or ConstantDataArray for array initializer!");
}
}
@@ -311,7 +340,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType,
SmallVector<Constant *> FlattenedElements;
collectElements(Init, FlattenedElements);
- assert(FlattenedType->getNumElements() == FlattenedElements.size() && "The number of collected elements should match the FlattenedType");
+ assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
+ "The number of collected elements should match the FlattenedType");
return ConstantArray::get(FlattenedType, FlattenedElements);
}
@@ -360,7 +390,7 @@ static bool flattenArrays(Module &M) {
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
flattenGlobalArrays(M, GlobalMap);
for (auto &F : make_early_inc_range(M.functions())) {
- if(F.isIntrinsic())
+ if (F.isIntrinsic())
continue;
MadeChange |= Impl.visit(F);
}
>From 9aa36f214f6a3d6d36de2b135d9cc395e033feb0 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 16 Oct 2024 21:26:11 -0400
Subject: [PATCH 5/7] add instructions to compute index
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 32 ++++++++++++++-----
llvm/test/CodeGen/DirectX/scalar-load.ll | 2 +-
2 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index b33e075342b597..cf2d3865f5c196 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -90,6 +90,7 @@ class DXILFlattenArraysVisitor
void
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
ArrayType *FlattenedArrayType, Value *PtrOperand,
+ unsigned &UseCount,
SmallVector<Value *> Indices = SmallVector<Value *>(),
SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
bool AllIndicesAreConstInt = true);
@@ -145,13 +146,16 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+ if (Indices.size() == 1)
+ return Indices[0];
+
Value *FlatIndex = Builder.getInt32(0);
unsigned Multiplier = 1;
- for (unsigned I = 0; I < Indices.size(); ++I) {
+ for (int I = Indices.size() - 1; I >= 0; --I) {
unsigned DimSize = Dims[I];
- Value *VDimSize = Builder.getInt32(DimSize * Multiplier);
- Value *ScaledIndex = Builder.CreateMul(Indices[I], VDimSize);
+ Value *VMultiplier = Builder.getInt32(Multiplier);
+ Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
Multiplier *= DimSize;
}
@@ -213,8 +217,8 @@ DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
- Value *PtrOperand, SmallVector<Value *> Indices, SmallVector<uint64_t> Dims,
- bool AllIndicesAreConstInt) {
+ Value *PtrOperand, unsigned &UseCount, SmallVector<Value *> Indices,
+ SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
Indices.push_back(LastIndex);
@@ -230,9 +234,17 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
- Indices, Dims, AllIndicesAreConstInt);
+ ++UseCount, Indices, Dims, AllIndicesAreConstInt);
}
}
+ assert(Dims.size() == Indices.size());
+  // If the std::moves did not happen, the GEP chain is incomplete;
+  // let's save the last state.
+ if (!Dims.empty())
+ GEPChainMap.insert(
+ {&CurrGEP,
+ {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+ std::move(Dims), AllIndicesAreConstInt}});
}
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
@@ -271,8 +283,12 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
Value *PtrOperand = GEP.getPointerOperand();
// if(isa<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1))) {
- recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand);
- PotentiallyDeadInstrs.emplace_back(&GEP);
+ unsigned UseCount = 0;
+ recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, UseCount);
+ if (UseCount == 0)
+ visitGetElementPtrInstInGEPChain(GEP);
+ else
+ PotentiallyDeadInstrs.emplace_back(&GEP);
/*} else {
SmallVector<Value *> Indices(GEP.idx_begin(),GEP.idx_end());
Value *FlatGEP =
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index 28a453da51f597..7d01c0cfa7fa69 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -46,7 +46,7 @@ define <4 x i32> @load_vec_test() #0 {
; CHECK-LABEL: load_static_array_of_vec_test
define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
; DATACHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
- ; CHECK: getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 %index
+ ; CHECK: getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 %index
; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
; CHECK-NOT: load i32, ptr {{.*}}, align 4
%3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
>From 9dfc97928adac10cbcb5e3f08053942ac1ec3c03 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 24 Oct 2024 15:40:33 -0400
Subject: [PATCH 6/7] save state
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 63 +++++++++++--------
llvm/test/CodeGen/DirectX/flatten-array.ll | 43 +++++++++++++
2 files changed, 81 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index cf2d3865f5c196..e4660909c438ee 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -90,12 +90,14 @@ class DXILFlattenArraysVisitor
void
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
ArrayType *FlattenedArrayType, Value *PtrOperand,
- unsigned &UseCount,
+ unsigned &GEPChainUseCount,
SmallVector<Value *> Indices = SmallVector<Value *>(),
SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
bool AllIndicesAreConstInt = true);
ConstantInt *computeFlatIndex(GetElementPtrInst &GEP);
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+ bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+ GetElementPtrInst &GEP);
};
bool DXILFlattenArraysVisitor::finish() {
@@ -217,7 +219,7 @@ DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
- Value *PtrOperand, unsigned &UseCount, SmallVector<Value *> Indices,
+ Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
@@ -225,32 +227,40 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
Dims.push_back(
cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
- if (!isMultiDimensionalArray(CurrGEP.getSourceElementType())) {
+ bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
+ if (!IsMultiDimArr) {
+ assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
GEPChainMap.insert(
{&CurrGEP,
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
std::move(Dims), AllIndicesAreConstInt}});
+ return;
}
+ bool GepUses = false;
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
- ++UseCount, Indices, Dims, AllIndicesAreConstInt);
+ ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt);
+ GepUses = true;
}
}
- assert(Dims.size() == Indices.size());
- // If the std::moves did not happen the gep chain is incomplete
- // let save the last state.
- if (!Dims.empty())
+  // This case is just in case the GEP chain doesn't end with a 1d array.
+ if(IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
GEPChainMap.insert(
{&CurrGEP,
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
std::move(Dims), AllIndicesAreConstInt}});
+ }
}
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
GetElementPtrInst &GEP) {
- IRBuilder<> Builder(&GEP);
GEPData GEPInfo = GEPChainMap.at(&GEP);
+ return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+ GEPData &GEPInfo, GetElementPtrInst &GEP) {
+ IRBuilder<> Builder(&GEP);
Value *FlatIndex;
if (GEPInfo.AllIndicesAreConstInt)
FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
@@ -282,22 +292,25 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
ArrayType::get(getBaseElementType(ArrType), TotalElements);
Value *PtrOperand = GEP.getPointerOperand();
- // if(isa<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1))) {
- unsigned UseCount = 0;
- recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, UseCount);
- if (UseCount == 0)
- visitGetElementPtrInstInGEPChain(GEP);
- else
- PotentiallyDeadInstrs.emplace_back(&GEP);
- /*} else {
- SmallVector<Value *> Indices(GEP.idx_begin(),GEP.idx_end());
- Value *FlatGEP =
- Builder.CreateGEP(FlattenedArrayType, PtrOperand, Indices,
- GEP.getName() + ".flat", GEP.isInBounds());
- GEP.replaceAllUsesWith(FlatGEP);
- GEP.eraseFromParent();
- }*/
- return true;
+
+ unsigned GEPChainUseCount = 0;
+ recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+ // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
+ // Here recursion is used to get the length of the GEP chain.
+ // Handle zero uses here because there won't be an update via
+ // a child in the chain later.
+ if (GEPChainUseCount == 0) {
+ SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
+ SmallVector<uint64_t> Dims({ArrType->getNumElements()});
+ bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
+ GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
+ std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
+ return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+ }
+
+ PotentiallyDeadInstrs.emplace_back(&GEP);
+ return false;
}
bool DXILFlattenArraysVisitor::visit(Function &F) {
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index dcda200c8aadf2..f3815b7f270717 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -136,6 +136,49 @@ define void @global_gep_load() {
ret void
}
+define void @global_gep_load_index(i32 %row, i32 %col, i32 %timeIndex) {
+; CHECK-LABEL: define void @global_gep_load_index(
+; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]], i32 [[TIMEINDEX:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[TIMEINDEX]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i32 0, [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = mul i32 [[COL]], 4
+; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[ROW]], 12
+; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 [[TMP6]]
+; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+; CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[DOTFLAT]], align 4
+; CHECK-NEXT: ret void
+;
+ %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 %row
+ %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 %col
+ %3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 %timeIndex
+ %4 = load i32, i32* %3, align 4
+ ret void
+}
+
+define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
+; CHECK-LABEL: define void @global_incomplete_gep_chain(
+; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[COL]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i32 0, [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = mul i32 [[ROW]], 3
+; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 [[TMP4]]
+; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
+; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
+; CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr [[DOTFLAT]], align 4
+; CHECK-NEXT: ret void
+;
+ %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 %row
+ %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 %col
+ %4 = load i32, i32* %2, align 4
+ ret void
+}
+
define void @global_gep_store() {
; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 13), align 4
; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
>From ecc84280b2ac8f056f9584344d631be4242cfd23 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 30 Oct 2024 19:27:09 -0400
Subject: [PATCH 7/7] make load/store gep constexpr split
---
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 86 +++++++++----------
llvm/test/CodeGen/DirectX/flatten-array.ll | 8 +-
2 files changed, 47 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index e4660909c438ee..20c7401e934e6c 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -20,6 +20,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
@@ -69,8 +70,8 @@ class DXILFlattenArraysVisitor
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
bool visitPHINode(PHINode &PHI) { return false; }
- bool visitLoadInst(LoadInst &LI) { return false; }
- bool visitStoreInst(StoreInst &SI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
bool visitCallInst(CallInst &ICI) { return false; }
bool visitFreezeInst(FreezeInst &FI) { return false; }
static bool isMultiDimensionalArray(Type *T);
@@ -94,7 +95,6 @@ class DXILFlattenArraysVisitor
SmallVector<Value *> Indices = SmallVector<Value *>(),
SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
bool AllIndicesAreConstInt = true);
- ConstantInt *computeFlatIndex(GetElementPtrInst &GEP);
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
GetElementPtrInst &GEP);
@@ -164,6 +164,38 @@ Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
return FlatIndex;
}
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+ unsigned NumOperands = LI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
+ Value *CurrOpperand = LI.getOperand(I);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ convertUsersOfConstantsToInstructions(CE,
+ /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false,
+ /*IncludeSelf=*/true);
+ return false;
+ }
+ }
+ return false;
+}
+
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+ unsigned NumOperands = SI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
+ Value *CurrOpperand = SI.getOperand(I);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ convertUsersOfConstantsToInstructions(CE,
+ /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false,
+ /*IncludeSelf=*/true);
+ return false;
+ }
+ }
+ return false;
+}
+
bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
if (!isMultiDimensionalArray(AI.getAllocatedType()))
return false;
@@ -182,41 +214,6 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
return true;
}
-ConstantInt *
-DXILFlattenArraysVisitor::computeFlatIndex(GetElementPtrInst &GEP) {
- unsigned IndexAmount = GEP.getNumIndices();
- assert(IndexAmount >= 1 && "Need At least one Index");
- if (IndexAmount == 1)
- return dyn_cast<ConstantInt>(GEP.getOperand(GEP.getNumOperands() - 1));
-
- // Get the type of the base pointer.
- Type *BaseType = GEP.getSourceElementType();
-
- // Determine the dimensions of the multi-dimensional array.
- SmallVector<int64_t> Dimensions;
- while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
- Dimensions.push_back(ArrType->getNumElements());
- BaseType = ArrType->getElementType();
- }
- unsigned FlatIndex = 0;
- unsigned Multiplier = 1;
- unsigned BitWidth = 32;
- for (const Use &Index : GEP.indices()) {
- ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
- BitWidth = CurrentIndex->getBitWidth();
- if (!CurrentIndex)
- return nullptr;
- int64_t IndexValue = CurrentIndex->getSExtValue();
- FlatIndex += IndexValue * Multiplier;
-
- if (!Dimensions.empty()) {
- Multiplier *= Dimensions.back(); // Use the last dimension size
- Dimensions.pop_back(); // Remove the last dimension
- }
- }
- return ConstantInt::get(GEP.getContext(), APInt(BitWidth, FlatIndex));
-}
-
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
@@ -240,12 +237,13 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
for (auto *User : CurrGEP.users()) {
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
- ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt);
+ ++GEPChainUseCount, Indices, Dims,
+ AllIndicesAreConstInt);
GepUses = true;
}
}
   // This case is just in case the GEP chain doesn't end with a 1d array.
- if(IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
+ if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
GEPChainMap.insert(
{&CurrGEP,
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
@@ -295,10 +293,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
unsigned GEPChainUseCount = 0;
recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-
+
// NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
// Here recursion is used to get the length of the GEP chain.
- // Handle zero uses here because there won't be an update via
+ // Handle zero uses here because there won't be an update via
// a child in the chain later.
if (GEPChainUseCount == 0) {
SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
@@ -308,7 +306,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
}
-
+
PotentiallyDeadInstrs.emplace_back(&GEP);
return false;
}
@@ -426,7 +424,7 @@ static bool flattenArrays(Module &M) {
for (auto &[Old, New] : GlobalMap) {
Old->replaceAllUsesWith(New);
Old->eraseFromParent();
- MadeChange |= true;
+ MadeChange = true;
}
return MadeChange;
}
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index f3815b7f270717..2dd7cd988837b5 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -125,7 +125,8 @@ define void @gep_4d_test () {
@b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16
define void @global_gep_load() {
- ; CHECK: load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 6), align 4
+ ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 6
+ ; CHECK: load i32, ptr [[GEP_PTR]], align 4
; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
@@ -180,7 +181,8 @@ define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
}
define void @global_gep_store() {
- ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 13), align 4
+ ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @b.1dim, i32 13
+ ; CHECK: store i32 1, ptr [[GEP_PTR]], align 4
; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
@@ -189,4 +191,4 @@ define void @global_gep_store() {
%3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 1
store i32 1, i32* %3, align 4
ret void
-}
\ No newline at end of file
+}
More information about the llvm-commits
mailing list