[llvm] [DirectX] Data Scalarization of Vectors in Global Scope (PR #110029)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 25 18:37:51 PDT 2024
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/110029
>From ae9090aa72f8511bbecbd1f3690ef6e7f452e864 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 23 Sep 2024 22:36:12 -0400
Subject: [PATCH 1/5] [DirectX] Data Scalarization
---
llvm/lib/Target/DirectX/CMakeLists.txt | 1 +
.../Target/DirectX/DXILDataScalarization.cpp | 312 ++++++++++++++++++
.../Target/DirectX/DXILDataScalarization.h | 35 ++
llvm/lib/Target/DirectX/DirectX.h | 6 +
.../Target/DirectX/DirectXTargetMachine.cpp | 2 +
llvm/test/CodeGen/DirectX/scalar-load.ll | 40 +++
llvm/test/CodeGen/DirectX/scalar-store.ll | 34 +-
7 files changed, 418 insertions(+), 12 deletions(-)
create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.cpp
create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.h
create mode 100644 llvm/test/CodeGen/DirectX/scalar-load.ll
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 7e0f8a145505e0..c8ef0ef6f7e702 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
DirectXTargetMachine.cpp
DirectXTargetTransformInfo.cpp
DXContainerGlobals.cpp
+ DXILDataScalarization.cpp
DXILFinalizeLinkage.cpp
DXILIntrinsicExpansion.cpp
DXILOpBuilder.cpp
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
new file mode 100644
index 00000000000000..e689c0b06fccfc
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -0,0 +1,312 @@
+//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
+//Legalization----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--------------------------------------------------------------------------------===//
+
+#include "DXILDataScalarization.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <utility>
+
+#define DEBUG_TYPE "dxil-data-scalarization"
+#define Max_VEC_SIZE 4
+
+using namespace llvm;
+
+static void findAndReplaceVectors(Module &M);
+
+class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
+public:
+ DataScalarizerVisitor() : GlobalMap() {}
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was scalarized,
+ // false if nothing changed.
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+ friend void findAndReplaceVectors(llvm::Module &M);
+
+private:
+ GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
+ DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+ bool finish();
+};
+
+bool DataScalarizerVisitor::visit(Function &F) {
+ assert(!GlobalMap.empty());
+ ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
+ for (BasicBlock *BB : RPOT) {
+ for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+ Instruction *I = &*II;
+ bool Done = InstVisitor::visit(I);
+ ++II;
+ if (Done && I->getType()->isVoidTy())
+ I->eraseFromParent();
+ }
+ }
+ return finish();
+}
+
+bool DataScalarizerVisitor::finish() {
+ // TODO this should do cleanup
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+ return true;
+}
+
+GlobalVariable *
+DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
+ if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
+ auto It = GlobalMap.find(OldGlobal);
+ if (It != GlobalMap.end()) {
+ return It->second; // Found, return the new global
+ }
+ }
+ return nullptr; // Not found
+}
+
+bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
+ for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
+ Value *CurrOpperand = LI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal)
+ LI.setOperand(I, NewGlobal);
+ }
+ return false;
+}
+
+bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
+ bool ReplaceStore = false;
+ for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
+ Value *CurrOpperand = SI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal) {
+ SI.setOperand(I, NewGlobal);
+ /*Value *StoredValue = SI.getValueOperand();
+ Type *StoredType = StoredValue->getType();
+ if (VectorType *VecTy = dyn_cast<VectorType>(StoredType)) {
+ unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements();
+ ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(),
+ NumElements); std::vector<Constant *> ConstElements; if (ConstantVector
+ *ConstVec = dyn_cast<ConstantVector>(StoredValue)) { for(uint I = 0; I <
+ NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I));
+ }
+ }
+ Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements);
+ IRBuilder<> Builder(&SI);
+ Builder.CreateStore(ArrayValue, SI.getPointerOperand());
+ replaceStore = true;
+ }*/
+ }
+ }
+ return ReplaceStore;
+}
+
+bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+ for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
+ Value *CurrOpperand = GEPI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal) {
+ // Prepare to create a new GEP for the new global
+ IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI
+
+ SmallVector<Value *, Max_VEC_SIZE> Indices;
+ for (auto &Index : GEPI.indices())
+ Indices.push_back(Index);
+
+ // Create a new GEP for the new global variable
+ Value *NewGEP =
+ Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+
+ // Replace the old GEP with the new one
+ GEPI.replaceAllUsesWith(NewGEP);
+ PotentiallyDeadInstrs.emplace_back(&GEPI);
+ }
+ }
+ return true;
+}
+
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+ if (auto *VecTy = dyn_cast<VectorType>(T))
+ return ArrayType::get(VecTy->getElementType(),
+ cast<FixedVectorType>(VecTy)->getNumElements());
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
+ Type *NewElementType =
+ replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
+ return ArrayType::get(NewElementType, ArrayTy->getNumElements());
+ }
+ // If it's not a vector or array, return the original type.
+ return T;
+}
+
+Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
+ LLVMContext &Ctx) {
+ // Handle ConstantAggregateZero (zero-initialized constants)
+ if (isa<ConstantAggregateZero>(Init)) {
+ return ConstantAggregateZero::get(NewType);
+ }
+
+ // Handle UndefValue (undefined constants)
+ if (isa<UndefValue>(Init)) {
+ return UndefValue::get(NewType);
+ }
+
+ // Handle vector to array transformation
+ if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
+ // Convert vector initializer to array initializer
+ auto *VecInit = dyn_cast<ConstantVector>(Init);
+ if (!VecInit) {
+ llvm_unreachable("Expected a ConstantVector for vector initializer!");
+ }
+
+ SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
+ for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) {
+ ArrayElements.push_back(VecInit->getOperand(I));
+ }
+
+ return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
+ }
+
+ // Handle array of vectors transformation
+ if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
+ // Recursively transform array elements
+ auto *ArrayInit = dyn_cast<ConstantArray>(Init);
+ if (!ArrayInit) {
+ llvm_unreachable("Expected a ConstantArray for array initializer!");
+ }
+
+ SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
+ for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
+ Constant *NewElemInit = transformInitializer(
+ ArrayInit->getOperand(I), ArrayTy->getElementType(),
+ cast<ArrayType>(NewType)->getElementType(), Ctx);
+ NewArrayElements.push_back(NewElemInit);
+ }
+
+ return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
+ }
+
+ // If not a vector or array, return the original initializer
+ return Init;
+}
+
+static void findAndReplaceVectors(Module &M) {
+ LLVMContext &Ctx = M.getContext();
+ IRBuilder<> Builder(Ctx);
+ DataScalarizerVisitor Impl;
+ for (GlobalVariable &G : M.globals()) {
+ Type *OrigType = G.getValueType();
+ // Recursively replace vectors in the type
+ Type *NewType = replaceVectorWithArray(OrigType, Ctx);
+ if (OrigType != NewType) {
+ // Create a new global variable with the updated type
+ GlobalVariable *NewGlobal = new GlobalVariable(
+ M, NewType, G.isConstant(), G.getLinkage(),
+ // This is set via: transformInitializer
+ nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(),
+ G.getAddressSpace(), G.isExternallyInitialized());
+
+ // Copy relevant attributes
+ NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
+ if (G.getAlignment() > 0) {
+ NewGlobal->setAlignment(Align(G.getAlignment()));
+ }
+
+ if (G.hasInitializer()) {
+ Constant *Init = G.getInitializer();
+ Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
+ NewGlobal->setInitializer(NewInit);
+ }
+
+ // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
+ // type equality
+ // So instead we will use the visitor pattern
+ Impl.GlobalMap[&G] = NewGlobal;
+ for (User *U : G.users()) {
+ if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
+ ConstantExpr *CE = cast<ConstantExpr>(U);
+ convertUsersOfConstantsToInstructions(CE,
+ /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false,
+ /*IncludeSelf=*/true);
+ }
+ }
+ // Uses should have grown
+ std::vector<User *> UsersToProcess;
+ // Collect all users first
+ // work around so I can delete
+ // in a loop body
+ for (User *U : G.users()) {
+ UsersToProcess.push_back(U);
+ }
+
+ // Now process each user
+ for (User *U : UsersToProcess) {
+ if (isa<Instruction>(U)) {
+ Instruction *Inst = cast<Instruction>(U);
+ Function *F = Inst->getFunction();
+ if (F)
+ Impl.visit(*F);
+ }
+ }
+ }
+ }
+
+ // Remove the old globals after the iteration
+ for (auto Pair : Impl.GlobalMap) {
+ GlobalVariable *OldG = Pair.getFirst();
+ OldG->eraseFromParent();
+ }
+}
+
+PreservedAnalyses DXILDataScalarization::run(Module &M,
+ ModuleAnalysisManager &) {
+ findAndReplaceVectors(M);
+ return PreservedAnalyses::none();
+}
+
+bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
+ findAndReplaceVectors(M);
+ return true;
+}
+
+void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {}
+
+char DXILDataScalarizationLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
+ "DXIL Data Scalarization", false, false)
+INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
+ "DXIL Data Scalarization", false, false)
+
+ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
+ return new DXILDataScalarizationLegacy();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
new file mode 100644
index 00000000000000..d06119397ddb25
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -0,0 +1,35 @@
+//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
+//Legalization----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===------------------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
+#define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass thattransforms Vectors to Arrays
+class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
+public:
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILDataScalarizationLegacy : public ModulePass {
+
+public:
+ bool runOnModule(Module &M) override;
+ DXILDataScalarizationLegacy() : ModulePass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 60fc5094542b37..3221779be2f311 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -34,6 +34,12 @@ void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
/// Pass to expand intrinsic operations that lack DXIL opCodes
ModulePass *createDXILIntrinsicExpansionLegacyPass();
+/// Initializer for DXIL Data Scalarization Pass
+void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
+
+/// Pass to scalarize llvm global data into a DXIL legal form
+ModulePass *createDXILDataScalarizationLegacyPass();
+
/// Initializer for DXILOpLowering
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 606022a9835f04..f358215ecf3735 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -46,6 +46,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
auto *PR = PassRegistry::getPassRegistry();
initializeDXILIntrinsicExpansionLegacyPass(*PR);
+ initializeDXILDataScalarizationLegacyPass(*PR);
initializeScalarizerLegacyPassPass(*PR);
initializeDXILPrepareModulePass(*PR);
initializeEmbedDXILPassPass(*PR);
@@ -86,6 +87,7 @@ class DirectXPassConfig : public TargetPassConfig {
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
void addCodeGenPrepare() override {
addPass(createDXILIntrinsicExpansionLegacyPass());
+ addPass(createDXILDataScalarizationLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createScalarizerPass(DxilScalarOptions));
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
new file mode 100644
index 00000000000000..bd99b63883e9f9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -0,0 +1,40 @@
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"vecData" = external addrspace(3) global <4 x i32>, align 4
+@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+
+; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
+; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4
+; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @vecData
+; CHECK-NOT: @staticArrayOfVecData
+
+; CHECK-LABEL: load_array_vec_test
+define <4 x i32> @load_array_vec_test() {
+ ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
+ ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
+ %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
+ %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
+ %3 = add <4 x i32> %1, %2
+ ret <4 x i32> %3
+}
+
+; CHECK-LABEL: load_vec_test
+define <4 x i32> @load_vec_test() {
+ ; CHECK-COUNT-4: load i32, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align {{.*}}
+ ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
+ %1 = load <4 x i32>, <4 x i32> addrspace(3)* @"vecData", align 4
+ ret <4 x i32> %1
+}
+
+; CHECK-LABEL: load_static_array_of_vec_test
+define <4 x i32> @load_static_array_of_vec_test(i32 %index) {
+ ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+ ; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
+ ; CHECK-NOT: load i32, ptr {{.*}}, align 4
+ %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
+ %4 = load <4 x i32>, <4 x i32>* %3, align 4
+ ret <4 x i32> %4
+}
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index b970a2842e5a8b..767d2e47c3e8eb 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,17 +1,27 @@
-; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
-@"sharedData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
-; CHECK-LABEL: store_test
-define void @store_test () local_unnamed_addr {
- ; CHECK: store float 1.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
- ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
- ; CHECK: store float 3.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
- ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
- ; CHECK: store float 4.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
- ; CHECK: store float 6.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"vecData" = external addrspace(3) global <4 x i32>, align 4
- store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"sharedData", align 16
- store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3) getelementptr inbounds (i8, ptr addrspace(3) @"sharedData", i32 16), align 16
+; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
+; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @vecData
+
+; CHECK-LABEL: store_array_vec_test
+define void @store_array_vec_test () local_unnamed_addr {
+ ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+ ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+ store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"arrayofVecData", align 16
+ store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3) getelementptr inbounds (i8, ptr addrspace(3) @"arrayofVecData", i32 16), align 16
ret void
}
+
+; CHECK-LABEL: store_vec_test
+define void @store_vec_test(<4 x i32> %inputVec) {
+ ; CHECK-COUNT-4: store i32 %inputVec.{{.*}}, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align 4
+ ; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3)
+ store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4
+ ret void
+}
\ No newline at end of file
>From 1e89daf5684508de52ba66ff54c8ae2d8342cd5a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Tue, 24 Sep 2024 02:29:05 -0400
Subject: [PATCH 2/5] cleanup comments and dead code
---
.../Target/DirectX/DXILDataScalarization.cpp | 34 +++++--------------
.../Target/DirectX/DXILDataScalarization.h | 4 +--
2 files changed, 11 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index e689c0b06fccfc..b586e758c473f8 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,5 +1,5 @@
//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
-//Legalization----===//
+// Legalization----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -77,7 +77,6 @@ bool DataScalarizerVisitor::visit(Function &F) {
}
bool DataScalarizerVisitor::finish() {
- // TODO this should do cleanup
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
return true;
}
@@ -104,30 +103,14 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
}
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
- bool ReplaceStore = false;
for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
Value *CurrOpperand = SI.getOperand(I);
GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
if (NewGlobal) {
SI.setOperand(I, NewGlobal);
- /*Value *StoredValue = SI.getValueOperand();
- Type *StoredType = StoredValue->getType();
- if (VectorType *VecTy = dyn_cast<VectorType>(StoredType)) {
- unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements();
- ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(),
- NumElements); std::vector<Constant *> ConstElements; if (ConstantVector
- *ConstVec = dyn_cast<ConstantVector>(StoredValue)) { for(uint I = 0; I <
- NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I));
- }
- }
- Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements);
- IRBuilder<> Builder(&SI);
- Builder.CreateStore(ArrayValue, SI.getPointerOperand());
- replaceStore = true;
- }*/
}
}
- return ReplaceStore;
+ return false;
}
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
@@ -135,18 +118,15 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
Value *CurrOpperand = GEPI.getOperand(I);
GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
if (NewGlobal) {
- // Prepare to create a new GEP for the new global
- IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI
+ IRBuilder<> Builder(&GEPI);
SmallVector<Value *, Max_VEC_SIZE> Indices;
for (auto &Index : GEPI.indices())
Indices.push_back(Index);
- // Create a new GEP for the new global variable
Value *NewGEP =
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
- // Replace the old GEP with the new one
GEPI.replaceAllUsesWith(NewGEP);
PotentiallyDeadInstrs.emplace_back(&GEPI);
}
@@ -154,6 +134,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
return true;
}
+// Recursively Creates and Array like version of the given vector like type.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
@@ -197,7 +178,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
// Handle array of vectors transformation
if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
- // Recursively transform array elements
+
auto *ArrayInit = dyn_cast<ConstantArray>(Init);
if (!ArrayInit) {
llvm_unreachable("Expected a ConstantArray for array initializer!");
@@ -205,6 +186,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
+ // Recursively transform array elements
Constant *NewElemInit = transformInitializer(
ArrayInit->getOperand(I), ArrayTy->getElementType(),
cast<ArrayType>(NewType)->getElementType(), Ctx);
@@ -224,7 +206,7 @@ static void findAndReplaceVectors(Module &M) {
DataScalarizerVisitor Impl;
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();
- // Recursively replace vectors in the type
+
Type *NewType = replaceVectorWithArray(OrigType, Ctx);
if (OrigType != NewType) {
// Create a new global variable with the updated type
@@ -251,6 +233,8 @@ static void findAndReplaceVectors(Module &M) {
// So instead we will use the visitor pattern
Impl.GlobalMap[&G] = NewGlobal;
for (User *U : G.users()) {
+ // Note: The GEPS are stored as constExprs
+ // This step flattens them out to instructions
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
ConstantExpr *CE = cast<ConstantExpr>(U);
convertUsersOfConstantsToInstructions(CE,
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
index d06119397ddb25..b6c59f7b33fd40 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.h
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -1,5 +1,5 @@
//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
-//Legalization----===//
+// Legalization----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,7 +15,7 @@
namespace llvm {
-/// A pass thattransforms Vectors to Arrays
+/// A pass that transforms Vectors to Arrays
class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
>From 2e21e3d2ea8df7bf2f5cbfa817268acb1e6f8152 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 03:43:26 -0400
Subject: [PATCH 3/5] check for ConstantDataVector
---
llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 15 ++++++++-------
llvm/test/CodeGen/DirectX/scalar-load.ll | 7 +++++--
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index b586e758c473f8..0557b21930b7fb 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -163,14 +163,15 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
// Handle vector to array transformation
if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
// Convert vector initializer to array initializer
- auto *VecInit = dyn_cast<ConstantVector>(Init);
- if (!VecInit) {
- llvm_unreachable("Expected a ConstantVector for vector initializer!");
- }
-
SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
- for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) {
- ArrayElements.push_back(VecInit->getOperand(I));
+ if( ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
+ for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
+ ArrayElements.push_back(ConstVecInit->getOperand(I));
+ } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
+ for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
+ ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
+ } else {
+ llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!");
}
return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index bd99b63883e9f9..fa1f1290867b24 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -2,11 +2,14 @@
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
@"vecData" = external addrspace(3) global <4 x i32>, align 4
-@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
+
; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
-; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4
+; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+
; CHECK-NOT: @arrayofVecData
; CHECK-NOT: @vecData
; CHECK-NOT: @staticArrayOfVecData
>From 6f6e42c881ba1be6387f953300b6488f2fa1f029 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 15:08:27 -0400
Subject: [PATCH 4/5] fix iterator stability issue
---
.../Target/DirectX/DXILDataScalarization.cpp | 53 +++++++------------
llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 +
llvm/test/CodeGen/DirectX/scalar-load.ll | 5 +-
llvm/test/CodeGen/DirectX/scalar-store.ll | 2 +-
4 files changed, 25 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 0557b21930b7fb..9fa39a5d71d86c 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,15 +1,15 @@
-//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
-// Legalization----===//
+//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===--------------------------------------------------------------------------------===//
+//===----------------------------------------------------------------===//
#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
@@ -20,7 +20,6 @@
#include "llvm/IR/Type.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
-#include <utility>
#define DEBUG_TYPE "dxil-data-scalarization"
#define Max_VEC_SIZE 4
@@ -164,14 +163,16 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
// Convert vector initializer to array initializer
SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
- if( ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
- for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
- ArrayElements.push_back(ConstVecInit->getOperand(I));
- } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
- for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
- ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
- } else {
- llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!");
+ if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
+ for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
+ ArrayElements.push_back(ConstVecInit->getOperand(I));
+ } else if (ConstantDataVector *ConstDataVecInit =
+ llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
+ for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
+ ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
+ } else {
+ llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
+ "vector initializer!");
}
return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
@@ -213,9 +214,10 @@ static void findAndReplaceVectors(Module &M) {
// Create a new global variable with the updated type
GlobalVariable *NewGlobal = new GlobalVariable(
M, NewType, G.isConstant(), G.getLinkage(),
- // This is set via: transformInitializer
- nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(),
- G.getAddressSpace(), G.isExternallyInitialized());
+ // Initializer is set via transformInitializer
+ /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
+ G.getThreadLocalMode(), G.getAddressSpace(),
+ G.isExternallyInitialized());
// Copy relevant attributes
NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
@@ -230,12 +232,9 @@ static void findAndReplaceVectors(Module &M) {
}
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
- // type equality
- // So instead we will use the visitor pattern
+ // type equality. Instead we will use the visitor pattern.
Impl.GlobalMap[&G] = NewGlobal;
- for (User *U : G.users()) {
- // Note: The GEPS are stored as constExprs
- // This step flattens them out to instructions
+ for (User *U : make_early_inc_range(G.users())) {
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
ConstantExpr *CE = cast<ConstantExpr>(U);
convertUsersOfConstantsToInstructions(CE,
@@ -243,18 +242,6 @@ static void findAndReplaceVectors(Module &M) {
/*RemoveDeadConstants=*/false,
/*IncludeSelf=*/true);
}
- }
- // Uses should have grown
- std::vector<User *> UsersToProcess;
- // Collect all users first
- // work around so I can delete
- // in a loop body
- for (User *U : G.users()) {
- UsersToProcess.push_back(U);
- }
-
- // Now process each user
- for (User *U : UsersToProcess) {
if (isa<Instruction>(U)) {
Instruction *Inst = cast<Instruction>(U);
Function *F = Inst->getFunction();
@@ -294,4 +281,4 @@ INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
return new DXILDataScalarizationLegacy();
-}
\ No newline at end of file
+}
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 46326d69175876..102748508b4ad7 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -8,6 +8,7 @@
; CHECK-NEXT: Target Transform Information
; CHECK-NEXT: ModulePass Manager
; CHECK-NEXT: DXIL Intrinsic Expansion
+; CHECK-NEXT: DXIL Data Scalarization
; CHECK-NEXT: FunctionPass Manager
; CHECK-NEXT: Dominator Tree Construction
; CHECK-NEXT: Scalarize vector operations
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index fa1f1290867b24..1f4834ebfd04f4 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -2,17 +2,18 @@
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
@"vecData" = external addrspace(3) global <4 x i32>, align 4
-;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-
+@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+; Check @staticArray
; CHECK-NOT: @arrayofVecData
; CHECK-NOT: @vecData
; CHECK-NOT: @staticArrayOfVecData
+; CHECK-NOT: @staticArray.scalarized
; CHECK-LABEL: load_array_vec_test
define <4 x i32> @load_array_vec_test() {
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index 767d2e47c3e8eb..aac4711c3f97f3 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -24,4 +24,4 @@ define void @store_vec_test(<4 x i32> %inputVec) {
; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3)
store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4
ret void
-}
\ No newline at end of file
+}
>From 72baf7242e14fd83886021916cef5702a6b7a33d Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 21:21:22 -0400
Subject: [PATCH 5/5] address pr comments
---
.../Target/DirectX/DXILDataScalarization.cpp | 106 ++++++++++--------
.../Target/DirectX/DXILDataScalarization.h | 16 +--
llvm/test/CodeGen/DirectX/scalar-data.ll | 12 ++
llvm/test/CodeGen/DirectX/scalar-load.ll | 20 +++-
llvm/test/CodeGen/DirectX/scalar-store.ll | 2 +
5 files changed, 95 insertions(+), 61 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/scalar-data.ll
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 9fa39a5d71d86c..0e6cf59e257508 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,15 +1,16 @@
-//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
+//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===----------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
@@ -22,11 +23,21 @@
#include "llvm/Transforms/Utils/Local.h"
#define DEBUG_TYPE "dxil-data-scalarization"
-#define Max_VEC_SIZE 4
+static const int MaxVecSize = 4;
using namespace llvm;
-static void findAndReplaceVectors(Module &M);
+class DXILDataScalarizationLegacy : public ModulePass {
+
+public:
+ bool runOnModule(Module &M) override;
+ DXILDataScalarizationLegacy() : ModulePass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ static char ID; // Pass identification.
+};
+
+static bool findAndReplaceVectors(Module &M);
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
@@ -51,10 +62,10 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
bool visitStoreInst(StoreInst &SI);
bool visitCallInst(CallInst &ICI) { return false; }
bool visitFreezeInst(FreezeInst &FI) { return false; }
- friend void findAndReplaceVectors(llvm::Module &M);
+ friend bool findAndReplaceVectors(llvm::Module &M);
private:
- GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
+ GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
bool finish();
@@ -81,7 +92,7 @@ bool DataScalarizerVisitor::finish() {
}
GlobalVariable *
-DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
+DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
auto It = GlobalMap.find(OldGlobal);
if (It != GlobalMap.end()) {
@@ -92,20 +103,20 @@ DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
}
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
- for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
+ unsigned NumOperands = LI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = LI.getOperand(I);
- GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
- if (NewGlobal)
+ if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
LI.setOperand(I, NewGlobal);
}
return false;
}
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
- for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
+ unsigned NumOperands = SI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = SI.getOperand(I);
- GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
- if (NewGlobal) {
+ if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
SI.setOperand(I, NewGlobal);
}
}
@@ -113,22 +124,23 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
}
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
- for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
+ unsigned NumOperands = GEPI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = GEPI.getOperand(I);
- GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
- if (NewGlobal) {
- IRBuilder<> Builder(&GEPI);
+ GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
+ if (!NewGlobal)
+ continue;
+ IRBuilder<> Builder(&GEPI);
- SmallVector<Value *, Max_VEC_SIZE> Indices;
- for (auto &Index : GEPI.indices())
- Indices.push_back(Index);
+ SmallVector<Value *, MaxVecSize> Indices;
+ for (auto &Index : GEPI.indices())
+ Indices.push_back(Index);
- Value *NewGEP =
- Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+ Value *NewGEP =
+ Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
- GEPI.replaceAllUsesWith(NewGEP);
- PotentiallyDeadInstrs.emplace_back(&GEPI);
- }
+ GEPI.replaceAllUsesWith(NewGEP);
+ PotentiallyDeadInstrs.emplace_back(&GEPI);
}
return true;
}
@@ -137,7 +149,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
- cast<FixedVectorType>(VecTy)->getNumElements());
+ dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
@@ -162,7 +174,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
// Handle vector to array transformation
if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
// Convert vector initializer to array initializer
- SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
+ SmallVector<Constant *, MaxVecSize> ArrayElements;
if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
ArrayElements.push_back(ConstVecInit->getOperand(I));
@@ -171,8 +183,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
} else {
- llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
- "vector initializer!");
+ assert(false && "Expected a ConstantVector or ConstantDataVector for "
+ "vector initializer!");
}
return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
@@ -180,13 +192,10 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
// Handle array of vectors transformation
if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
-
auto *ArrayInit = dyn_cast<ConstantArray>(Init);
- if (!ArrayInit) {
- llvm_unreachable("Expected a ConstantArray for array initializer!");
- }
+ assert(ArrayInit && "Expected a ConstantArray for array initializer!");
- SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
+ SmallVector<Constant *, MaxVecSize> NewArrayElements;
for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
// Recursively transform array elements
Constant *NewElemInit = transformInitializer(
@@ -202,7 +211,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
return Init;
}
-static void findAndReplaceVectors(Module &M) {
+static bool findAndReplaceVectors(Module &M) {
+ bool MadeChange = false;
LLVMContext &Ctx = M.getContext();
IRBuilder<> Builder(Ctx);
DataScalarizerVisitor Impl;
@@ -212,9 +222,9 @@ static void findAndReplaceVectors(Module &M) {
Type *NewType = replaceVectorWithArray(OrigType, Ctx);
if (OrigType != NewType) {
// Create a new global variable with the updated type
+ // Note: Initializer is set via transformInitializer
GlobalVariable *NewGlobal = new GlobalVariable(
M, NewType, G.isConstant(), G.getLinkage(),
- // Initializer is set via transformInitializer
/*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
G.getThreadLocalMode(), G.getAddressSpace(),
G.isExternallyInitialized());
@@ -222,7 +232,7 @@ static void findAndReplaceVectors(Module &M) {
// Copy relevant attributes
NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
if (G.getAlignment() > 0) {
- NewGlobal->setAlignment(Align(G.getAlignment()));
+ NewGlobal->setAlignment(G.getAlign());
}
if (G.hasInitializer()) {
@@ -253,24 +263,30 @@ static void findAndReplaceVectors(Module &M) {
}
// Remove the old globals after the iteration
- for (auto Pair : Impl.GlobalMap) {
- GlobalVariable *OldG = Pair.getFirst();
- OldG->eraseFromParent();
+ for (auto &[Old, New] : Impl.GlobalMap) {
+ Old->eraseFromParent();
+ MadeChange = true;
}
+ return MadeChange;
}
PreservedAnalyses DXILDataScalarization::run(Module &M,
ModuleAnalysisManager &) {
- findAndReplaceVectors(M);
- return PreservedAnalyses::none();
+ bool MadeChanges = findAndReplaceVectors(M);
+ if (!MadeChanges)
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserve<DXILResourceAnalysis>();
+ return PA;
}
bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
- findAndReplaceVectors(M);
- return true;
+ return findAndReplaceVectors(M);
}
-void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {}
+void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addPreserved<DXILResourceWrapperPass>();
+}
char DXILDataScalarizationLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
index b6c59f7b33fd40..560e061db96d08 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.h
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -1,11 +1,11 @@
-//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
-// Legalization----===//
+//===- DXILDataScalarization.h - Perform DXIL Data Legalization -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===------------------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
+
#ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
#define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
@@ -20,16 +20,6 @@ class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};
-
-class DXILDataScalarizationLegacy : public ModulePass {
-
-public:
- bool runOnModule(Module &M) override;
- DXILDataScalarizationLegacy() : ModulePass(ID) {}
-
- void getAnalysisUsage(AnalysisUsage &AU) const override;
- static char ID; // Pass identification.
-};
} // namespace llvm
#endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
diff --git a/llvm/test/CodeGen/DirectX/scalar-data.ll b/llvm/test/CodeGen/DirectX/scalar-data.ll
new file mode 100644
index 00000000000000..4438604a3a8797
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalar-data.ll
@@ -0,0 +1,12 @@
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+
+; Make sure we don't touch arrays without vectors and that we can recurse through multi-dimensional arrays of vectors
+
+@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
+@"groushared3dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x <4 x i32>]]] zeroinitializer, align 16
+
+; CHECK @staticArray
+; CHECK-NOT: @staticArray.scalarized
+; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
+; CHECK-NOT: @groushared3dArrayofVectors
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index 1f4834ebfd04f4..11678f48a5e015 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -1,19 +1,23 @@
; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+
+; Make sure we can load groupshared, static vectors and arrays of vectors
+
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
@"vecData" = external addrspace(3) global <4 x i32>, align 4
@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
+@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
-; Check @staticArray
+; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
; CHECK-NOT: @arrayofVecData
; CHECK-NOT: @vecData
; CHECK-NOT: @staticArrayOfVecData
-; CHECK-NOT: @staticArray.scalarized
+; CHECK-NOT: @groushared2dArrayofVectors
+
; CHECK-LABEL: load_array_vec_test
define <4 x i32> @load_array_vec_test() {
@@ -42,3 +46,13 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) {
%4 = load <4 x i32>, <4 x i32>* %3, align 4
ret <4 x i32> %4
}
+
+; CHECK-LABEL: multid_load_test
+define <4 x i32> @multid_load_test() {
+ ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4
+ ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
+ %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
+ %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
+ %3 = add <4 x i32> %1, %2
+ ret <4 x i32> %3
+}
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index aac4711c3f97f3..08d8a2c57c6c33 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,6 +1,8 @@
; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+; Make sure we can store groupshared, static vectors and arrays of vectors
+
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
@"vecData" = external addrspace(3) global <4 x i32>, align 4
More information about the llvm-commits
mailing list