[llvm] [DirectX] Scalarize Allocas as part of data scalarization (PR #140165)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu May 22 22:03:12 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/140165

>From d33fd575f1fe81a99dabfdfce10aecca9a567be9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 15 May 2025 18:26:28 -0400
Subject: [PATCH 1/2] [DirectX] Scalarize Allocas as part of data scalarization

- DXILDataScalarization should not be limited to global data.
- Add scalarization for allocas whose allocated type is an array of vectors.
- Walk each function's basic blocks in reverse post-order and run
  DataScalarizerVisitor on every instruction.
- Fixes #140143
---
 .../Target/DirectX/DXILDataScalarization.cpp  | 83 ++++++++++++-------
 .../test/CodeGen/DirectX/scalar-bug-117273.ll | 18 ++--
 llvm/test/CodeGen/DirectX/scalarize-alloca.ll | 10 +++
 3 files changed, 76 insertions(+), 35 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/scalarize-alloca.ll

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1f2700ac55647..1209bcdfb2891 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -10,6 +10,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 public:
   DataScalarizerVisitor() : GlobalMap() {}
-  bool visit(Instruction &I);
+  bool visit(Function &F);
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
+  bool visitAllocaInst(AllocaInst &AI);
   bool visitInstruction(Instruction &I) { return false; }
   bool visitSelectInst(SelectInst &SI) { return false; }
   bool visitICmpInst(ICmpInst &ICI) { return false; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 private:
   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  static bool isArrayOfVectors(Type *T);
 };
 
-bool DataScalarizerVisitor::visit(Instruction &I) {
-  assert(!GlobalMap.empty());
-  return InstVisitor::visit(I);
+bool DataScalarizerVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+    for (Instruction &I : make_early_inc_range(*BB))
+      MadeChange |= InstVisitor::visit(I);
+  }
+  return MadeChange;
 }
 
 GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
   return nullptr; // Not found
 }
 
+// Recursively Creates and Array like version of the given vector like type.
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+  if (auto *VecTy = dyn_cast<VectorType>(T))
+    return ArrayType::get(VecTy->getElementType(),
+                          cast<FixedVectorType>(VecTy)->getNumElements());
+  // Recurse so vectors nested inside multi-dimensional arrays are rewritten.
+  if (auto *ArrayTy = dyn_cast<ArrayType>(T))
+    return ArrayType::get(
+        replaceVectorWithArray(ArrayTy->getElementType(), Ctx),
+        ArrayTy->getNumElements());
+  // If it's not a vector or array, return the original type.
+  return T;
+}
+
+bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
+  if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+    return isa<VectorType>(ArrType->getElementType());
+  return false;
+}
+
+// Rewrites an alloca of an array of vectors as an alloca of nested arrays.
+bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
+  Type *OldTy = AI.getAllocatedType();
+  if (!isArrayOfVectors(OldTy))
+    return false;
+  // Keep the address space (RAUW needs the same pointer type) and array size.
+  IRBuilder<> Builder(&AI);
+  AllocaInst *ArrAlloca = Builder.CreateAlloca(
+      replaceVectorWithArray(OldTy, AI.getContext()), AI.getAddressSpace(),
+      AI.getArraySize(), AI.getName() + ".scalarize");
+  ArrAlloca->setAlignment(AI.getAlign());
+  AI.replaceAllUsesWith(ArrAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
   unsigned NumOperands = LI.getNumOperands();
   for (unsigned I = 0; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
   return true;
 }
 
-// Recursively Creates and Array like version of the given vector like type.
-static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
-  if (auto *VecTy = dyn_cast<VectorType>(T))
-    return ArrayType::get(VecTy->getElementType(),
-                          dyn_cast<FixedVectorType>(VecTy)->getNumElements());
-  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
-    Type *NewElementType =
-        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
-    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
-  }
-  // If it's not a vector or array, return the original type.
-  return T;
-}
-
 Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
                                LLVMContext &Ctx) {
   // Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
       // type equality. Instead we will use the visitor pattern.
       Impl.GlobalMap[&G] = NewGlobal;
-      for (User *U : make_early_inc_range(G.users())) {
-        if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
-          ConstantExpr *CE = cast<ConstantExpr>(U);
-          for (User *UCE : make_early_inc_range(CE->users())) {
-            if (Instruction *Inst = dyn_cast<Instruction>(UCE))
-              Impl.visit(*Inst);
-          }
-        }
-        if (Instruction *Inst = dyn_cast<Instruction>(U))
-          Impl.visit(*Inst);
-      }
     }
   }
 
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (F.isDeclaration())
+      continue;
+    MadeChange |= Impl.visit(F);
+  }
+
   // Remove the old globals after the iteration
   for (auto &[Old, New] : Impl.GlobalMap) {
     Old->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index 25dc2c36b4e1f..2676abec1d8ae 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -8,12 +8,18 @@
 define internal void @main() #1 {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
-; CHECK-NEXT:    [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
-; CHECK-NEXT:    [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
-; CHECK-NEXT:    [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
-; CHECK-NEXT:    [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
-; CHECK-NEXT:    [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
+; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
+; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
+; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
+; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
+; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
+; CHECK-NEXT:    [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
+; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
new file mode 100644
index 0000000000000..4829f3a31791f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
+; Verify that an alloca of an array of vectors is scalarized, then flattened.
+; CHECK-LABEL: alloca_2d__vec_test
+define void @alloca_2d__vec_test() local_unnamed_addr #2 {
+  ; SCHECK:  alloca [2 x [4 x i32]], align 16
+  ; FCHECK:  alloca [8 x i32], align 16
+  %1 = alloca [2 x <4 x i32>], align 16
+  ret void
+}

>From ab82e4cdfc6582dc2fab525a5209d9f3056b4d8a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 23 May 2025 01:02:49 -0400
Subject: [PATCH 2/2] address pr comments

---
 llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1209bcdfb2891..06708cec00cec 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -67,7 +67,6 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 private:
   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
-  static bool isArrayOfVectors(Type *T);
 };
 
 bool DataScalarizerVisitor::visit(Function &F) {
@@ -91,7 +90,7 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
   return nullptr; // Not found
 }
 
-// Recursively Creates and Array like version of the given vector like type.
+// Recursively creates an array version of the given vector type.
 static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
   if (auto *VecTy = dyn_cast<VectorType>(T))
     return ArrayType::get(VecTy->getElementType(),
@@ -105,7 +104,7 @@ static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
   return T;
 }
 
-bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
+static bool isArrayOfVectors(Type *T) {
   if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
     return isa<VectorType>(ArrType->getElementType());
   return false;



More information about the llvm-commits mailing list