[llvm] [DirectX] Scalarize Allocas as part of data scalarization (PR #140165)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu May 22 22:03:12 PDT 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/140165
>From d33fd575f1fe81a99dabfdfce10aecca9a567be9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 15 May 2025 18:26:28 -0400
Subject: [PATCH 1/2] [DirectX] Scalarize Allocas as part of data scalarization
- DXILDataScalarization should not just be limited to global data
- Add scalarization for allocas whose allocated type is an array of vectors
- Traverse each function's basic blocks in reverse post-order and run
DataScalarizerVisitor on every instruction.
- fixes #140143
---
.../Target/DirectX/DXILDataScalarization.cpp | 83 ++++++++++++-------
.../test/CodeGen/DirectX/scalar-bug-117273.ll | 18 ++--
llvm/test/CodeGen/DirectX/scalarize-alloca.ll | 10 +++
3 files changed, 76 insertions(+), 35 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/scalarize-alloca.ll
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1f2700ac55647..1209bcdfb2891 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -10,6 +10,7 @@
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
DataScalarizerVisitor() : GlobalMap() {}
- bool visit(Instruction &I);
+ bool visit(Function &F);
// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
+ bool visitAllocaInst(AllocaInst &AI);
bool visitInstruction(Instruction &I) { return false; }
bool visitSelectInst(SelectInst &SI) { return false; }
bool visitICmpInst(ICmpInst &ICI) { return false; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
private:
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ static bool isArrayOfVectors(Type *T);
};
-bool DataScalarizerVisitor::visit(Instruction &I) {
- assert(!GlobalMap.empty());
- return InstVisitor::visit(I);
+bool DataScalarizerVisitor::visit(Function &F) {
+ bool MadeChange = false;
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+ for (Instruction &I : make_early_inc_range(*BB))
+ MadeChange |= InstVisitor::visit(I);
+ }
+ return MadeChange;
}
GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}
+// Recursively Creates and Array like version of the given vector like type.
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+ if (auto *VecTy = dyn_cast<VectorType>(T))
+ return ArrayType::get(VecTy->getElementType(),
+ dyn_cast<FixedVectorType>(VecTy)->getNumElements());
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
+ Type *NewElementType =
+ replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
+ return ArrayType::get(NewElementType, ArrayTy->getNumElements());
+ }
+ // If it's not a vector or array, return the original type.
+ return T;
+}
+
+bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+ return isa<VectorType>(ArrType->getElementType());
+ return false;
+}
+
+bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
+ if (!isArrayOfVectors(AI.getAllocatedType()))
+ return false;
+
+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+ IRBuilder<> Builder(&AI);
+ LLVMContext &Ctx = AI.getContext();
+ Type *NewType = replaceVectorWithArray(ArrType, Ctx);
+ AllocaInst *ArrAlloca =
+ Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
+ ArrAlloca->setAlignment(AI.getAlign());
+ AI.replaceAllUsesWith(ArrAlloca);
+ AI.eraseFromParent();
+ return true;
+}
+
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
unsigned NumOperands = LI.getNumOperands();
for (unsigned I = 0; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
return true;
}
-// Recursively Creates and Array like version of the given vector like type.
-static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
- if (auto *VecTy = dyn_cast<VectorType>(T))
- return ArrayType::get(VecTy->getElementType(),
- dyn_cast<FixedVectorType>(VecTy)->getNumElements());
- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
- Type *NewElementType =
- replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
- return ArrayType::get(NewElementType, ArrayTy->getNumElements());
- }
- // If it's not a vector or array, return the original type.
- return T;
-}
-
Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
LLVMContext &Ctx) {
// Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
// type equality. Instead we will use the visitor pattern.
Impl.GlobalMap[&G] = NewGlobal;
- for (User *U : make_early_inc_range(G.users())) {
- if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
- ConstantExpr *CE = cast<ConstantExpr>(U);
- for (User *UCE : make_early_inc_range(CE->users())) {
- if (Instruction *Inst = dyn_cast<Instruction>(UCE))
- Impl.visit(*Inst);
- }
- }
- if (Instruction *Inst = dyn_cast<Instruction>(U))
- Impl.visit(*Inst);
- }
}
}
+ for (auto &F : make_early_inc_range(M.functions())) {
+ if (F.isDeclaration())
+ continue;
+ MadeChange |= Impl.visit(F);
+ }
+
// Remove the old globals after the iteration
for (auto &[Old, New] : Impl.GlobalMap) {
Old->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index 25dc2c36b4e1f..2676abec1d8ae 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -8,12 +8,18 @@
define internal void @main() #1 {
; CHECK-LABEL: define internal void @main() {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
-; CHECK-NEXT: [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
-; CHECK-NEXT: [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
-; CHECK-NEXT: [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
-; CHECK-NEXT: [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
-; CHECK-NEXT: [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1
+; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
+; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
+; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
+; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
+; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2
+; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
+; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
+; CHECK-NEXT: [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
+; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
; CHECK-NEXT: ret void
;
entry:
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
new file mode 100644
index 0000000000000..4829f3a31791f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
+
+; CHECK-LABEL: alloca_2d__vec_test
+define void @alloca_2d__vec_test() local_unnamed_addr #2 {
+ ; SCHECK: alloca [2 x [4 x i32]], align 16
+ ; FCHECK: alloca [8 x i32], align 16
+ %1 = alloca [2 x <4 x i32>], align 16
+ ret void
+}
>From ab82e4cdfc6582dc2fab525a5209d9f3056b4d8a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 23 May 2025 01:02:49 -0400
Subject: [PATCH 2/2] address pr comments
---
llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1209bcdfb2891..06708cec00cec 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -67,7 +67,6 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
private:
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
- static bool isArrayOfVectors(Type *T);
};
bool DataScalarizerVisitor::visit(Function &F) {
@@ -91,7 +90,7 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}
-// Recursively Creates and Array like version of the given vector like type.
+// Recursively creates an array version of the given vector type.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
@@ -105,7 +104,7 @@ static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
return T;
}
-bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
+static bool isArrayOfVectors(Type *T) {
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
return isa<VectorType>(ArrType->getElementType());
return false;
More information about the llvm-commits
mailing list