[llvm] [DirectX] Data Scalarization of Vectors in Global Scope (PR #110029)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 25 18:22:57 PDT 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/110029

>From ae9090aa72f8511bbecbd1f3690ef6e7f452e864 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 23 Sep 2024 22:36:12 -0400
Subject: [PATCH 1/5] [DirectX] Data Scalarization

---
 llvm/lib/Target/DirectX/CMakeLists.txt        |   1 +
 .../Target/DirectX/DXILDataScalarization.cpp  | 312 ++++++++++++++++++
 .../Target/DirectX/DXILDataScalarization.h    |  35 ++
 llvm/lib/Target/DirectX/DirectX.h             |   6 +
 .../Target/DirectX/DirectXTargetMachine.cpp   |   2 +
 llvm/test/CodeGen/DirectX/scalar-load.ll      |  40 +++
 llvm/test/CodeGen/DirectX/scalar-store.ll     |  34 +-
 7 files changed, 418 insertions(+), 12 deletions(-)
 create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.cpp
 create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.h
 create mode 100644 llvm/test/CodeGen/DirectX/scalar-load.ll

diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 7e0f8a145505e0..c8ef0ef6f7e702 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
   DirectXTargetMachine.cpp
   DirectXTargetTransformInfo.cpp
   DXContainerGlobals.cpp
+  DXILDataScalarization.cpp
   DXILFinalizeLinkage.cpp
   DXILIntrinsicExpansion.cpp
   DXILOpBuilder.cpp
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
new file mode 100644
index 00000000000000..e689c0b06fccfc
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -0,0 +1,312 @@
+//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
+//Legalization----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--------------------------------------------------------------------------------===//
+
+#include "DXILDataScalarization.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <utility>
+
+#define DEBUG_TYPE "dxil-data-scalarization"
+#define Max_VEC_SIZE 4
+
+using namespace llvm;
+
+static void findAndReplaceVectors(Module &M);
+
+class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
+public:
+  DataScalarizerVisitor() : GlobalMap() {}
+  bool visit(Function &F);
+  // InstVisitor methods.  They return true if the instruction was scalarized,
+  // false if nothing changed.
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitSelectInst(SelectInst &SI) { return false; }
+  bool visitICmpInst(ICmpInst &ICI) { return false; }
+  bool visitFCmpInst(FCmpInst &FCI) { return false; }
+  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+  bool visitCastInst(CastInst &CI) { return false; }
+  bool visitBitCastInst(BitCastInst &BCI) { return false; }
+  bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+  bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+  bool visitPHINode(PHINode &PHI) { return false; }
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+  bool visitCallInst(CallInst &ICI) { return false; }
+  bool visitFreezeInst(FreezeInst &FI) { return false; }
+  friend void findAndReplaceVectors(llvm::Module &M);
+
+private:
+  GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
+  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+  bool finish();
+};
+
+bool DataScalarizerVisitor::visit(Function &F) {
+  assert(!GlobalMap.empty());
+  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
+  for (BasicBlock *BB : RPOT) {
+    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+      Instruction *I = &*II;
+      bool Done = InstVisitor::visit(I);
+      ++II;
+      if (Done && I->getType()->isVoidTy())
+        I->eraseFromParent();
+    }
+  }
+  return finish();
+}
+
+bool DataScalarizerVisitor::finish() {
+  // TODO this should do cleanup
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+  return true;
+}
+
+GlobalVariable *
+DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
+  if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
+    auto It = GlobalMap.find(OldGlobal);
+    if (It != GlobalMap.end()) {
+      return It->second; // Found, return the new global
+    }
+  }
+  return nullptr; // Not found
+}
+
+bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
+  for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
+    Value *CurrOpperand = LI.getOperand(I);
+    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+    if (NewGlobal)
+      LI.setOperand(I, NewGlobal);
+  }
+  return false;
+}
+
+bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
+  bool ReplaceStore = false;
+  for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
+    Value *CurrOpperand = SI.getOperand(I);
+    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+    if (NewGlobal) {
+      SI.setOperand(I, NewGlobal);
+      /*Value *StoredValue = SI.getValueOperand();
+      Type *StoredType = StoredValue->getType();
+      if (VectorType *VecTy = dyn_cast<VectorType>(StoredType)) {
+          unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements();
+          ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(),
+      NumElements); std::vector<Constant *> ConstElements; if (ConstantVector
+      *ConstVec = dyn_cast<ConstantVector>(StoredValue)) { for(uint I = 0; I <
+      NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I));
+              }
+          }
+          Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements);
+          IRBuilder<> Builder(&SI);
+          Builder.CreateStore(ArrayValue, SI.getPointerOperand());
+          replaceStore = true;
+      }*/
+    }
+  }
+  return ReplaceStore;
+}
+
+bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+  for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
+    Value *CurrOpperand = GEPI.getOperand(I);
+    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+    if (NewGlobal) {
+      // Prepare to create a new GEP for the new global
+      IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI
+
+      SmallVector<Value *, Max_VEC_SIZE> Indices;
+      for (auto &Index : GEPI.indices())
+        Indices.push_back(Index);
+
+      // Create a new GEP for the new global variable
+      Value *NewGEP =
+          Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+
+      // Replace the old GEP with the new one
+      GEPI.replaceAllUsesWith(NewGEP);
+      PotentiallyDeadInstrs.emplace_back(&GEPI);
+    }
+  }
+  return true;
+}
+
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+  if (auto *VecTy = dyn_cast<VectorType>(T))
+    return ArrayType::get(VecTy->getElementType(),
+                          cast<FixedVectorType>(VecTy)->getNumElements());
+  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
+    Type *NewElementType =
+        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
+    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
+  }
+  // If it's not a vector or array, return the original type.
+  return T;
+}
+
+Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
+                               LLVMContext &Ctx) {
+  // Handle ConstantAggregateZero (zero-initialized constants)
+  if (isa<ConstantAggregateZero>(Init)) {
+    return ConstantAggregateZero::get(NewType);
+  }
+
+  // Handle UndefValue (undefined constants)
+  if (isa<UndefValue>(Init)) {
+    return UndefValue::get(NewType);
+  }
+
+  // Handle vector to array transformation
+  if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
+    // Convert vector initializer to array initializer
+    auto *VecInit = dyn_cast<ConstantVector>(Init);
+    if (!VecInit) {
+      llvm_unreachable("Expected a ConstantVector for vector initializer!");
+    }
+
+    SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
+    for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) {
+      ArrayElements.push_back(VecInit->getOperand(I));
+    }
+
+    return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
+  }
+
+  // Handle array of vectors transformation
+  if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
+    // Recursively transform array elements
+    auto *ArrayInit = dyn_cast<ConstantArray>(Init);
+    if (!ArrayInit) {
+      llvm_unreachable("Expected a ConstantArray for array initializer!");
+    }
+
+    SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
+    for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
+      Constant *NewElemInit = transformInitializer(
+          ArrayInit->getOperand(I), ArrayTy->getElementType(),
+          cast<ArrayType>(NewType)->getElementType(), Ctx);
+      NewArrayElements.push_back(NewElemInit);
+    }
+
+    return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
+  }
+
+  // If not a vector or array, return the original initializer
+  return Init;
+}
+
+static void findAndReplaceVectors(Module &M) {
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> Builder(Ctx);
+  DataScalarizerVisitor Impl;
+  for (GlobalVariable &G : M.globals()) {
+    Type *OrigType = G.getValueType();
+    // Recursively replace vectors in the type
+    Type *NewType = replaceVectorWithArray(OrigType, Ctx);
+    if (OrigType != NewType) {
+      // Create a new global variable with the updated type
+      GlobalVariable *NewGlobal = new GlobalVariable(
+          M, NewType, G.isConstant(), G.getLinkage(),
+          // This is set via: transformInitializer
+          nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(),
+          G.getAddressSpace(), G.isExternallyInitialized());
+
+      // Copy relevant attributes
+      NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
+      if (G.getAlignment() > 0) {
+        NewGlobal->setAlignment(Align(G.getAlignment()));
+      }
+
+      if (G.hasInitializer()) {
+        Constant *Init = G.getInitializer();
+        Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
+        NewGlobal->setInitializer(NewInit);
+      }
+
+      // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
+      // type equality
+      //  So instead we will use the visitor pattern
+      Impl.GlobalMap[&G] = NewGlobal;
+      for (User *U : G.users()) {
+        if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
+          ConstantExpr *CE = cast<ConstantExpr>(U);
+          convertUsersOfConstantsToInstructions(CE,
+                                                /*RestrictToFunc=*/nullptr,
+                                                /*RemoveDeadConstants=*/false,
+                                                /*IncludeSelf=*/true);
+        }
+      }
+      // Uses should have grown
+      std::vector<User *> UsersToProcess;
+      // Collect all users first
+      // work around so I can delete
+      // in a loop body
+      for (User *U : G.users()) {
+        UsersToProcess.push_back(U);
+      }
+
+      // Now process each user
+      for (User *U : UsersToProcess) {
+        if (isa<Instruction>(U)) {
+          Instruction *Inst = cast<Instruction>(U);
+          Function *F = Inst->getFunction();
+          if (F)
+            Impl.visit(*F);
+        }
+      }
+    }
+  }
+
+  // Remove the old globals after the iteration
+  for (auto Pair : Impl.GlobalMap) {
+    GlobalVariable *OldG = Pair.getFirst();
+    OldG->eraseFromParent();
+  }
+}
+
+PreservedAnalyses DXILDataScalarization::run(Module &M,
+                                             ModuleAnalysisManager &) {
+  findAndReplaceVectors(M);
+  return PreservedAnalyses::none();
+}
+
+bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
+  findAndReplaceVectors(M);
+  return true;
+}
+
+void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {}
+
+char DXILDataScalarizationLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
+                      "DXIL Data Scalarization", false, false)
+INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
+                    "DXIL Data Scalarization", false, false)
+
+ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
+  return new DXILDataScalarizationLegacy();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
new file mode 100644
index 00000000000000..d06119397ddb25
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -0,0 +1,35 @@
+//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
+//Legalization----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===------------------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
+#define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass thattransforms Vectors to Arrays
+class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILDataScalarizationLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILDataScalarizationLegacy() : ModulePass(ID) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 60fc5094542b37..3221779be2f311 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -34,6 +34,12 @@ void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
 /// Pass to expand intrinsic operations that lack DXIL opCodes
 ModulePass *createDXILIntrinsicExpansionLegacyPass();
 
+/// Initializer for DXIL Data Scalarization Pass
+void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
+
+/// Pass to scalarize llvm global data into a DXIL legal form
+ModulePass *createDXILDataScalarizationLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 606022a9835f04..f358215ecf3735 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -46,6 +46,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
   initializeDXILIntrinsicExpansionLegacyPass(*PR);
+  initializeDXILDataScalarizationLegacyPass(*PR);
   initializeScalarizerLegacyPassPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
@@ -86,6 +87,7 @@ class DirectXPassConfig : public TargetPassConfig {
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
+    addPass(createDXILDataScalarizationLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
     addPass(createScalarizerPass(DxilScalarOptions));
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
new file mode 100644
index 00000000000000..bd99b63883e9f9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -0,0 +1,40 @@
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"vecData" = external addrspace(3) global <4 x i32>, align 4
+@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+
+; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
+; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4
+; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @vecData
+; CHECK-NOT: @staticArrayOfVecData
+
+; CHECK-LABEL: load_array_vec_test
+define <4 x i32> @load_array_vec_test() {
+  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
+  ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
+  %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
+  %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
+  %3 = add <4 x i32> %1, %2
+  ret <4 x i32> %3
+}
+
+; CHECK-LABEL: load_vec_test
+define <4 x i32> @load_vec_test() {
+  ; CHECK-COUNT-4: load i32, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align {{.*}}
+  ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4 
+  %1 = load <4 x i32>, <4 x i32> addrspace(3)* @"vecData", align 4
+  ret <4 x i32> %1
+}
+
+; CHECK-LABEL: load_static_array_of_vec_test
+define <4 x i32> @load_static_array_of_vec_test(i32 %index) {
+  ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+  ; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
+  ; CHECK-NOT: load i32, ptr {{.*}}, align 4
+  %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
+  %4 = load <4 x i32>, <4 x i32>* %3, align 4
+  ret <4 x i32> %4
+}
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index b970a2842e5a8b..767d2e47c3e8eb 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,17 +1,27 @@
-; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 
-@"sharedData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 
-; CHECK-LABEL: store_test
-define void @store_test () local_unnamed_addr {
-    ; CHECK: store float 1.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} 
-    ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}}
-    ; CHECK: store float 3.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} 
-    ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} 
-    ; CHECK: store float 4.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} 
-    ; CHECK: store float 6.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} 
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"vecData" = external addrspace(3) global <4 x i32>, align 4
 
-    store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"sharedData", align 16 
-    store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3)   getelementptr inbounds (i8, ptr addrspace(3) @"sharedData", i32 16), align 16 
+; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
+; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
+; CHECK-NOT: @arrayofVecData
+; CHECK-NOT: @vecData
+
+; CHECK-LABEL: store_array_vec_test
+define void @store_array_vec_test () local_unnamed_addr {
+    ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+    ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}}
+    store <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>, ptr addrspace(3) @"arrayofVecData", align 16 
+    store <3 x float> <float 2.000000e+00, float 4.000000e+00, float 6.000000e+00>, ptr addrspace(3)   getelementptr inbounds (i8, ptr addrspace(3) @"arrayofVecData", i32 16), align 16 
     ret void
  } 
+
+; CHECK-LABEL: store_vec_test
+define void @store_vec_test(<4 x i32> %inputVec) {
+  ; CHECK-COUNT-4: store i32 %inputVec.{{.*}}, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align 4 
+  ; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3)
+  store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4
+  ret void
+}
\ No newline at end of file

>From 1e89daf5684508de52ba66ff54c8ae2d8342cd5a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Tue, 24 Sep 2024 02:29:05 -0400
Subject: [PATCH 2/5] cleanup comments and dead code

---
 .../Target/DirectX/DXILDataScalarization.cpp  | 34 +++++--------------
 .../Target/DirectX/DXILDataScalarization.h    |  4 +--
 2 files changed, 11 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index e689c0b06fccfc..b586e758c473f8 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,5 +1,5 @@
 //===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
-//Legalization----===//
+// Legalization----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -77,7 +77,6 @@ bool DataScalarizerVisitor::visit(Function &F) {
 }
 
 bool DataScalarizerVisitor::finish() {
-  // TODO this should do cleanup
   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
   return true;
 }
@@ -104,30 +103,14 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
 }
 
 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
-  bool ReplaceStore = false;
   for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
     Value *CurrOpperand = SI.getOperand(I);
     GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
     if (NewGlobal) {
       SI.setOperand(I, NewGlobal);
-      /*Value *StoredValue = SI.getValueOperand();
-      Type *StoredType = StoredValue->getType();
-      if (VectorType *VecTy = dyn_cast<VectorType>(StoredType)) {
-          unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements();
-          ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(),
-      NumElements); std::vector<Constant *> ConstElements; if (ConstantVector
-      *ConstVec = dyn_cast<ConstantVector>(StoredValue)) { for(uint I = 0; I <
-      NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I));
-              }
-          }
-          Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements);
-          IRBuilder<> Builder(&SI);
-          Builder.CreateStore(ArrayValue, SI.getPointerOperand());
-          replaceStore = true;
-      }*/
     }
   }
-  return ReplaceStore;
+  return false;
 }
 
 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
@@ -135,18 +118,15 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
     Value *CurrOpperand = GEPI.getOperand(I);
     GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
     if (NewGlobal) {
-      // Prepare to create a new GEP for the new global
-      IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI
+      IRBuilder<> Builder(&GEPI);
 
       SmallVector<Value *, Max_VEC_SIZE> Indices;
       for (auto &Index : GEPI.indices())
         Indices.push_back(Index);
 
-      // Create a new GEP for the new global variable
       Value *NewGEP =
           Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
 
-      // Replace the old GEP with the new one
       GEPI.replaceAllUsesWith(NewGEP);
       PotentiallyDeadInstrs.emplace_back(&GEPI);
     }
@@ -154,6 +134,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
   return true;
 }
 
+// Recursively Creates and Array like version of the given vector like type.
 static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
   if (auto *VecTy = dyn_cast<VectorType>(T))
     return ArrayType::get(VecTy->getElementType(),
@@ -197,7 +178,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
 
   // Handle array of vectors transformation
   if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
-    // Recursively transform array elements
+
     auto *ArrayInit = dyn_cast<ConstantArray>(Init);
     if (!ArrayInit) {
       llvm_unreachable("Expected a ConstantArray for array initializer!");
@@ -205,6 +186,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
 
     SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
     for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
+      // Recursively transform array elements
       Constant *NewElemInit = transformInitializer(
           ArrayInit->getOperand(I), ArrayTy->getElementType(),
           cast<ArrayType>(NewType)->getElementType(), Ctx);
@@ -224,7 +206,7 @@ static void findAndReplaceVectors(Module &M) {
   DataScalarizerVisitor Impl;
   for (GlobalVariable &G : M.globals()) {
     Type *OrigType = G.getValueType();
-    // Recursively replace vectors in the type
+
     Type *NewType = replaceVectorWithArray(OrigType, Ctx);
     if (OrigType != NewType) {
       // Create a new global variable with the updated type
@@ -251,6 +233,8 @@ static void findAndReplaceVectors(Module &M) {
       //  So instead we will use the visitor pattern
       Impl.GlobalMap[&G] = NewGlobal;
       for (User *U : G.users()) {
+        // Note: The GEPS are stored as constExprs
+        // This step flattens them out to instructions
         if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
           ConstantExpr *CE = cast<ConstantExpr>(U);
           convertUsersOfConstantsToInstructions(CE,
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
index d06119397ddb25..b6c59f7b33fd40 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.h
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -1,5 +1,5 @@
 //===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
-//Legalization----===//
+// Legalization----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,7 +15,7 @@
 
 namespace llvm {
 
-/// A pass thattransforms Vectors to Arrays
+/// A pass that transforms Vectors to Arrays
 class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &);

>From 2e21e3d2ea8df7bf2f5cbfa817268acb1e6f8152 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 03:43:26 -0400
Subject: [PATCH 3/5] check for ConstantDataVector

---
 llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 15 ++++++++-------
 llvm/test/CodeGen/DirectX/scalar-load.ll          |  7 +++++--
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index b586e758c473f8..0557b21930b7fb 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -163,14 +163,15 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
   // Handle vector to array transformation
   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
     // Convert vector initializer to array initializer
-    auto *VecInit = dyn_cast<ConstantVector>(Init);
-    if (!VecInit) {
-      llvm_unreachable("Expected a ConstantVector for vector initializer!");
-    }
-
     SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
-    for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) {
-      ArrayElements.push_back(VecInit->getOperand(I));
+    if( ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
+        for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) 
+            ArrayElements.push_back(ConstVecInit->getOperand(I));
+    } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
+        for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) 
+            ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
+    }  else {
+      llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!");
     }
 
     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index bd99b63883e9f9..fa1f1290867b24 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -2,11 +2,14 @@
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
-@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
+@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
+
 
 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
-; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4
+; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+
 ; CHECK-NOT: @arrayofVecData
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData

>From 6f6e42c881ba1be6387f953300b6488f2fa1f029 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 15:08:27 -0400
Subject: [PATCH 4/5] fix iterator stability issue

---
 .../Target/DirectX/DXILDataScalarization.cpp  | 53 +++++++------------
 llvm/test/CodeGen/DirectX/llc-pipeline.ll     |  1 +
 llvm/test/CodeGen/DirectX/scalar-load.ll      |  5 +-
 llvm/test/CodeGen/DirectX/scalar-store.ll     |  2 +-
 4 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 0557b21930b7fb..9fa39a5d71d86c 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,15 +1,15 @@
-//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data
-// Legalization----===//
+//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===--------------------------------------------------------------------------------===//
+//===----------------------------------------------------------------===//
 
 #include "DXILDataScalarization.h"
 #include "DirectX.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
@@ -20,7 +20,6 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include <utility>
 
 #define DEBUG_TYPE "dxil-data-scalarization"
 #define Max_VEC_SIZE 4
@@ -164,14 +163,16 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
     // Convert vector initializer to array initializer
     SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
-    if( ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
-        for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) 
-            ArrayElements.push_back(ConstVecInit->getOperand(I));
-    } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
-        for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) 
-            ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
-    }  else {
-      llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!");
+    if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
+      for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
+        ArrayElements.push_back(ConstVecInit->getOperand(I));
+    } else if (ConstantDataVector *ConstDataVecInit =
+                   llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
+      for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
+        ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
+    } else {
+      llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
+                       "vector initializer!");
     }
 
     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
@@ -213,9 +214,10 @@ static void findAndReplaceVectors(Module &M) {
       // Create a new global variable with the updated type
       GlobalVariable *NewGlobal = new GlobalVariable(
           M, NewType, G.isConstant(), G.getLinkage(),
-          // This is set via: transformInitializer
-          nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(),
-          G.getAddressSpace(), G.isExternallyInitialized());
+          // Initializer is set via transformInitializer
+          /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
+          G.getThreadLocalMode(), G.getAddressSpace(),
+          G.isExternallyInitialized());
 
       // Copy relevant attributes
       NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
@@ -230,12 +232,9 @@ static void findAndReplaceVectors(Module &M) {
       }
 
       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
-      // type equality
-      //  So instead we will use the visitor pattern
+      // type equality. Instead we will use the visitor pattern.
       Impl.GlobalMap[&G] = NewGlobal;
-      for (User *U : G.users()) {
-        // Note: The GEPS are stored as constExprs
-        // This step flattens them out to instructions
+      for (User *U : make_early_inc_range(G.users())) {
         if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
           ConstantExpr *CE = cast<ConstantExpr>(U);
           convertUsersOfConstantsToInstructions(CE,
@@ -243,18 +242,6 @@ static void findAndReplaceVectors(Module &M) {
                                                 /*RemoveDeadConstants=*/false,
                                                 /*IncludeSelf=*/true);
         }
-      }
-      // Uses should have grown
-      std::vector<User *> UsersToProcess;
-      // Collect all users first
-      // work around so I can delete
-      // in a loop body
-      for (User *U : G.users()) {
-        UsersToProcess.push_back(U);
-      }
-
-      // Now process each user
-      for (User *U : UsersToProcess) {
         if (isa<Instruction>(U)) {
           Instruction *Inst = cast<Instruction>(U);
           Function *F = Inst->getFunction();
@@ -294,4 +281,4 @@ INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
 
 ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
   return new DXILDataScalarizationLegacy();
-}
\ No newline at end of file
+}
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 46326d69175876..102748508b4ad7 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -8,6 +8,7 @@
 ; CHECK-NEXT: Target Transform Information
 ; CHECK-NEXT: ModulePass Manager
 ; CHECK-NEXT:   DXIL Intrinsic Expansion
+; CHECK-NEXT:   DXIL Data Scalarization
 ; CHECK-NEXT:   FunctionPass Manager
 ; CHECK-NEXT:     Dominator Tree Construction
 ; CHECK-NEXT:     Scalarize vector operations
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index fa1f1290867b24..1f4834ebfd04f4 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -2,17 +2,18 @@
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
-;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-
+@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
 
 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
 ; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
+; Check @staticArray
 
 ; CHECK-NOT: @arrayofVecData
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData
+; CHECK-NOT: @staticArray.scalarized
 
 ; CHECK-LABEL: load_array_vec_test
 define <4 x i32> @load_array_vec_test() {
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index 767d2e47c3e8eb..aac4711c3f97f3 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -24,4 +24,4 @@ define void @store_vec_test(<4 x i32> %inputVec) {
   ; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3)
   store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4
   ret void
-}
\ No newline at end of file
+}

>From 3910ec672d382f44a796cf61cb6a973e029cce52 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 25 Sep 2024 21:21:22 -0400
Subject: [PATCH 5/5] address pr comments

---
 .../Target/DirectX/DXILDataScalarization.cpp  | 106 ++++++++++--------
 .../Target/DirectX/DXILDataScalarization.h    |  16 +--
 llvm/test/CodeGen/DirectX/scalar-data.ll      |  12 ++
 llvm/test/CodeGen/DirectX/scalar-load.ll      |  20 +++-
 llvm/test/CodeGen/DirectX/scalar-store.ll     |   2 +
 5 files changed, 95 insertions(+), 61 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/scalar-data.ll

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 9fa39a5d71d86c..0e6cf59e257508 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -1,15 +1,16 @@
-//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
+//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===----------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
 
 #include "DXILDataScalarization.h"
 #include "DirectX.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
@@ -22,11 +23,21 @@
 #include "llvm/Transforms/Utils/Local.h"
 
 #define DEBUG_TYPE "dxil-data-scalarization"
-#define Max_VEC_SIZE 4
+static const int MaxVecSize = 4;
 
 using namespace llvm;
 
-static void findAndReplaceVectors(Module &M);
+class DXILDataScalarizationLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILDataScalarizationLegacy() : ModulePass(ID) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  static char ID; // Pass identification.
+};
+
+static bool findAndReplaceVectors(Module &M);
 
 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 public:
@@ -51,10 +62,10 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
   bool visitStoreInst(StoreInst &SI);
   bool visitCallInst(CallInst &ICI) { return false; }
   bool visitFreezeInst(FreezeInst &FI) { return false; }
-  friend void findAndReplaceVectors(llvm::Module &M);
+  friend bool findAndReplaceVectors(llvm::Module &M);
 
 private:
-  GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
+  GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
   SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
   bool finish();
@@ -81,7 +92,7 @@ bool DataScalarizerVisitor::finish() {
 }
 
 GlobalVariable *
-DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
+DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
   if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
     auto It = GlobalMap.find(OldGlobal);
     if (It != GlobalMap.end()) {
@@ -92,20 +103,20 @@ DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
 }
 
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
-  for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
+  unsigned NumOperands = LI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = LI.getOperand(I);
-    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
-    if (NewGlobal)
+    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
       LI.setOperand(I, NewGlobal);
   }
   return false;
 }
 
 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
-  for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
+  unsigned NumOperands = SI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = SI.getOperand(I);
-    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
-    if (NewGlobal) {
+    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
       SI.setOperand(I, NewGlobal);
     }
   }
@@ -113,22 +124,23 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
 }
 
 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-  for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
+  unsigned NumOperands = GEPI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = GEPI.getOperand(I);
-    GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
-    if (NewGlobal) {
-      IRBuilder<> Builder(&GEPI);
+    GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
+    if (!NewGlobal)
+      continue;
+    IRBuilder<> Builder(&GEPI);
 
-      SmallVector<Value *, Max_VEC_SIZE> Indices;
-      for (auto &Index : GEPI.indices())
-        Indices.push_back(Index);
+    SmallVector<Value *, MaxVecSize> Indices;
+    for (auto &Index : GEPI.indices())
+      Indices.push_back(Index);
 
-      Value *NewGEP =
-          Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+    Value *NewGEP =
+        Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
 
-      GEPI.replaceAllUsesWith(NewGEP);
-      PotentiallyDeadInstrs.emplace_back(&GEPI);
-    }
+    GEPI.replaceAllUsesWith(NewGEP);
+    PotentiallyDeadInstrs.emplace_back(&GEPI);
   }
   return true;
 }
@@ -137,7 +149,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
 static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
   if (auto *VecTy = dyn_cast<VectorType>(T))
     return ArrayType::get(VecTy->getElementType(),
-                          cast<FixedVectorType>(VecTy)->getNumElements());
+                          dyn_cast<FixedVectorType>(VecTy)->getNumElements());
   if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
     Type *NewElementType =
         replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
@@ -162,7 +174,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
   // Handle vector to array transformation
   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
     // Convert vector initializer to array initializer
-    SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
+    SmallVector<Constant *, MaxVecSize> ArrayElements;
     if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
       for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
         ArrayElements.push_back(ConstVecInit->getOperand(I));
@@ -171,8 +183,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
       for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
         ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
     } else {
-      llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
-                       "vector initializer!");
+      assert(false && "Expected a ConstantVector or ConstantDataVector for "
+                      "vector initializer!");
     }
 
     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
@@ -180,13 +192,10 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
 
   // Handle array of vectors transformation
   if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
-
     auto *ArrayInit = dyn_cast<ConstantArray>(Init);
-    if (!ArrayInit) {
-      llvm_unreachable("Expected a ConstantArray for array initializer!");
-    }
+    assert(ArrayInit && "Expected a ConstantArray for array initializer!");
 
-    SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
+    SmallVector<Constant *, MaxVecSize> NewArrayElements;
     for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
       // Recursively transform array elements
       Constant *NewElemInit = transformInitializer(
@@ -202,7 +211,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
   return Init;
 }
 
-static void findAndReplaceVectors(Module &M) {
+static bool findAndReplaceVectors(Module &M) {
+  bool MadeChange = false;
   LLVMContext &Ctx = M.getContext();
   IRBuilder<> Builder(Ctx);
   DataScalarizerVisitor Impl;
@@ -212,9 +222,9 @@ static void findAndReplaceVectors(Module &M) {
     Type *NewType = replaceVectorWithArray(OrigType, Ctx);
     if (OrigType != NewType) {
       // Create a new global variable with the updated type
+      // Note: Initializer is set via transformInitializer
       GlobalVariable *NewGlobal = new GlobalVariable(
           M, NewType, G.isConstant(), G.getLinkage(),
-          // Initializer is set via transformInitializer
           /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
           G.getThreadLocalMode(), G.getAddressSpace(),
           G.isExternallyInitialized());
@@ -222,7 +232,7 @@ static void findAndReplaceVectors(Module &M) {
       // Copy relevant attributes
       NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
       if (G.getAlignment() > 0) {
-        NewGlobal->setAlignment(Align(G.getAlignment()));
+        NewGlobal->setAlignment(G.getAlign());
       }
 
       if (G.hasInitializer()) {
@@ -253,24 +263,30 @@ static void findAndReplaceVectors(Module &M) {
   }
 
   // Remove the old globals after the iteration
-  for (auto Pair : Impl.GlobalMap) {
-    GlobalVariable *OldG = Pair.getFirst();
-    OldG->eraseFromParent();
+  for (auto &[Old, New] : Impl.GlobalMap) {
+    Old->eraseFromParent();
+    MadeChange = true;
   }
+  return MadeChange;
 }
 
 PreservedAnalyses DXILDataScalarization::run(Module &M,
                                              ModuleAnalysisManager &) {
-  findAndReplaceVectors(M);
-  return PreservedAnalyses::none();
+  bool MadeChanges = findAndReplaceVectors(M);
+  if (!MadeChanges)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  return PA;
 }
 
 bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
-  findAndReplaceVectors(M);
-  return true;
+  return findAndReplaceVectors(M);
 }
 
-void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {}
+void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addPreserved<DXILResourceWrapperPass>();
+}
 
 char DXILDataScalarizationLegacy::ID = 0;
 
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h
index b6c59f7b33fd40..560e061db96d08 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.h
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h
@@ -1,11 +1,11 @@
-//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
-// Legalization----===//
+//===- DXILDataScalarization.h - Perform DXIL Data Legalization -*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===------------------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
+
 #ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
 #define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
 
@@ -20,16 +20,6 @@ class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
 };
-
-class DXILDataScalarizationLegacy : public ModulePass {
-
-public:
-  bool runOnModule(Module &M) override;
-  DXILDataScalarizationLegacy() : ModulePass(ID) {}
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override;
-  static char ID; // Pass identification.
-};
 } // namespace llvm
 
 #endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
diff --git a/llvm/test/CodeGen/DirectX/scalar-data.ll b/llvm/test/CodeGen/DirectX/scalar-data.ll
new file mode 100644
index 00000000000000..4438604a3a8797
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalar-data.ll
@@ -0,0 +1,12 @@
+; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+
+; Make sure we don't touch arrays without vectors and that we can recurse into multi-dimensional arrays of vectors
+
+@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
+@"groushared3dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x <4 x i32>]]] zeroinitializer, align 16
+
+; CHECK @staticArray
+; CHECK-NOT: @staticArray.scalarized
+; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
+; CHECK-NOT: @groushared3dArrayofVectors
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index 1f4834ebfd04f4..e6303d19ada7d0 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -1,19 +1,23 @@
 ; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
+
+; Make sure we can load groupshared, static vectors and arrays of vectors
+
 @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
+@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
 
 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
 ; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
-; Check @staticArray
+; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
 
 ; CHECK-NOT: @arrayofVecData
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData
-; CHECK-NOT: @staticArray.scalarized
+; CHECK-NOT: @groushared2dArrayofVectors
+
 
 ; CHECK-LABEL: load_array_vec_test
 define <4 x i32> @load_array_vec_test() {
@@ -42,3 +46,13 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) {
   %4 = load <4 x i32>, <4 x i32>* %3, align 4
   ret <4 x i32> %4
 }
+
+; CHECK-LABEL: multid_load_test
+define <4 x i32> @multid_load_test()  {
+  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4
+  ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
+  %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
+  %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
+  %3 = add <4 x i32> %1, %2
+  ret <4 x i32> %3
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll
index aac4711c3f97f3..08d8a2c57c6c33 100644
--- a/llvm/test/CodeGen/DirectX/scalar-store.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-store.ll
@@ -1,6 +1,8 @@
 ; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
 
+; Make sure we can store groupshared, static vectors and arrays of vectors
+
 @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
 



More information about the llvm-commits mailing list