[llvm] [DirectX] Bug fix for Data Scalarization crash (PR #118426)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 11:28:49 PST 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/118426

>From 3fa78a615dce3bb0da93f69714ed3146de6e4685 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 2 Dec 2024 19:09:15 -0500
Subject: [PATCH 1/2] [DirectX] Bug fix for Data Scalarization crash

---
 .../Target/DirectX/DXILDataScalarization.cpp  | 28 +++++-----------
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 12 +++++++
 .../CodeGen/DirectX/flatten-bug-117273.ll     | 25 +++++++++++++++
 .../test/CodeGen/DirectX/scalar-bug-117273.ll | 32 +++++++++++++++++++
 4 files changed, 77 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
 create mode 100644 llvm/test/CodeGen/DirectX/scalar-bug-117273.ll

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 0e6cf59e257508..42e40d146ff5c1 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -67,28 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 private:
   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
-  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
-  bool finish();
 };
 
 bool DataScalarizerVisitor::visit(Function &F) {
   assert(!GlobalMap.empty());
-  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
-  for (BasicBlock *BB : RPOT) {
-    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
-      Instruction *I = &*II;
-      bool Done = InstVisitor::visit(I);
-      ++II;
-      if (Done && I->getType()->isVoidTy())
-        I->eraseFromParent();
-    }
+  bool MadeChange = false;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+    for (Instruction &I : make_early_inc_range(*BB))
+      MadeChange |= InstVisitor::visit(I);
   }
-  return finish();
-}
-
-bool DataScalarizerVisitor::finish() {
-  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
-  return true;
+  return MadeChange;
 }
 
 GlobalVariable *
@@ -140,7 +129,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
         Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
 
     GEPI.replaceAllUsesWith(NewGEP);
-    PotentiallyDeadInstrs.emplace_back(&GEPI);
+    GEPI.eraseFromParent();
   }
   return true;
 }
@@ -252,8 +241,7 @@ static bool findAndReplaceVectors(Module &M) {
                                                 /*RemoveDeadConstants=*/false,
                                                 /*IncludeSelf=*/true);
         }
-        if (isa<Instruction>(U)) {
-          Instruction *Inst = cast<Instruction>(U);
+        if (Instruction *Inst = dyn_cast<Instruction>(U)) {
           Function *F = Inst->getFunction();
           if (F)
             Impl.visit(*F);
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index e4a3bc76eeacd2..fbf0e4bc35d6c6 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -321,6 +321,18 @@ static void collectElements(Constant *Init,
     Elements.push_back(Init);
     return;
   }
+  auto *LLVMArrayType = dyn_cast<ArrayType>(Init->getType());
+  if (isa<ConstantAggregateZero>(Init)) {
+    for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
+      Elements.push_back(
+          Constant::getNullValue(LLVMArrayType->getElementType()));
+    return;
+  }
+  if (isa<UndefValue>(Init)) {
+    for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
+      Elements.push_back(UndefValue::get(LLVMArrayType->getElementType()));
+    return;
+  }
 
   // Recursive case: Process each element in the array.
   if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
new file mode 100644
index 00000000000000..1e0f4fb3d382a1
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -0,0 +1,25 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-flatten-arrays,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+
+ at _ZL7Palette = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] undef], align 16
+
+; CHECK: @_ZL7Palette.1dim = internal constant [6 x float] [float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float undef, float undef, float undef], align 16
+
+define internal void @_Z4mainDv3_j(<3 x i32> noundef %DID) {
+; CHECK-LABEL: define internal void @_Z4mainDv3_j(
+; CHECK-SAME: <3 x i32> noundef [[DID:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 1
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 2
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 1
+  %.i0 = load float, ptr %0, align 16
+  %1 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 2
+  %.i03 = load float, ptr %1, align 16
+  ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
new file mode 100644
index 00000000000000..aa09fa401a6492
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+
+ at _ZL7Palette = internal constant [8 x <3 x float>] [<3 x float> zeroinitializer, <3 x float> splat (float 5.000000e-01), <3 x float> <float 1.000000e+00, float 5.000000e-01, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 1.000000e+00, float 5.000000e-01>], align 16
+
+; Function Attrs: alwaysinline convergent mustprogress norecurse nounwind
+define internal void @_Z4mainDv3_j(<3 x i32> noundef %DID) #1 {
+; CHECK-LABEL: define internal void @_Z4mainDv3_j(
+; CHECK-SAME: <3 x i32> noundef [[DID:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.scalarized.1dim, i32 1
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
+; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
+; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
+; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
+; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.scalarized.1dim, i32 2
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
+; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
+; CHECK-NEXT:    [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
+; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %arrayidx = getelementptr inbounds [8 x <3 x float>], ptr @_ZL7Palette, i32 0, i32 1
+  %2 = load <3 x float>, ptr %arrayidx, align 16
+  %arrayidx2 = getelementptr inbounds [8 x <3 x float>], ptr @_ZL7Palette, i32 0, i32 2
+  %3 = load <3 x float>, ptr %arrayidx2, align 16
+  ret void
+}

>From 58abdecb57552a87da716ab887528e596208f446 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 16 Dec 2024 12:14:29 -0500
Subject: [PATCH 2/2] 1. address PR comments 2. remove use of
 convertUsersOfConstantsToInstructions (was leaving use chains in an unstable
 state). 3. remove RPOT for DXILDataScalarization.cpp; we were visiting twice.

---
 .../Target/DirectX/DXILDataScalarization.cpp  | 108 ++++++++++++------
 llvm/lib/Target/DirectX/DXILFlattenArrays.cpp |  47 +++++---
 .../CodeGen/DirectX/flatten-bug-117273.ll     |  16 ++-
 .../DirectX/llc-vector-load-scalarize.ll      |  29 ++---
 .../test/CodeGen/DirectX/scalar-bug-117273.ll |  29 ++---
 llvm/test/CodeGen/DirectX/scalar-load.ll      |  28 +++--
 6 files changed, 150 insertions(+), 107 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 42e40d146ff5c1..fbb736a858aaae 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -42,7 +42,7 @@ static bool findAndReplaceVectors(Module &M);
 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 public:
   DataScalarizerVisitor() : GlobalMap() {}
-  bool visit(Function &F);
+  bool visit(Instruction &I);
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
   bool visitInstruction(Instruction &I) { return false; }
@@ -65,19 +65,14 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
   friend bool findAndReplaceVectors(llvm::Module &M);
 
 private:
+  Value *createNewGetElementPtr(GetElementPtrInst &GEPI);
   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
 };
 
-bool DataScalarizerVisitor::visit(Function &F) {
+bool DataScalarizerVisitor::visit(Instruction &I) {
   assert(!GlobalMap.empty());
-  bool MadeChange = false;
-  ReversePostOrderTraversal<Function *> RPOT(&F);
-  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
-    for (Instruction &I : make_early_inc_range(*BB))
-      MadeChange |= InstVisitor::visit(I);
-  }
-  return MadeChange;
+  return InstVisitor::visit(I);
 }
 
 GlobalVariable *
@@ -95,6 +90,21 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
   unsigned NumOperands = LI.getNumOperands();
   for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = LI.getOperand(I);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      GetElementPtrInst *OldGEP =
+          cast<GetElementPtrInst>(CE->getAsInstruction());
+      OldGEP->insertBefore(&LI);
+      Value *NewGEP = createNewGetElementPtr(*OldGEP);
+      IRBuilder<> Builder(&LI);
+      LoadInst *NewLoad =
+          Builder.CreateLoad(LI.getType(), NewGEP, LI.getName());
+      NewLoad->setAlignment(LI.getAlign());
+      LI.replaceAllUsesWith(NewLoad);
+      LI.eraseFromParent();
+      OldGEP->eraseFromParent();
+      return true;
+    }
     if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
       LI.setOperand(I, NewGlobal);
   }
@@ -105,32 +115,69 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
   unsigned NumOperands = SI.getNumOperands();
   for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = SI.getOperand(I);
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
-      SI.setOperand(I, NewGlobal);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      GetElementPtrInst *OldGEP =
+          cast<GetElementPtrInst>(CE->getAsInstruction());
+      OldGEP->insertBefore(&SI);
+      Value *NewGEP = createNewGetElementPtr(*OldGEP);
+      IRBuilder<> Builder(&SI);
+      StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), NewGEP);
+      NewStore->setAlignment(SI.getAlign());
+      SI.replaceAllUsesWith(NewStore);
+      SI.eraseFromParent();
+      OldGEP->eraseFromParent();
+      return true;
     }
+    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
+      SI.setOperand(I, NewGlobal);
   }
   return false;
 }
 
-bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+Value *DataScalarizerVisitor::createNewGetElementPtr(GetElementPtrInst &GEPI) {
   unsigned NumOperands = GEPI.getNumOperands();
+  GlobalVariable *NewGlobal = nullptr;
   for (unsigned I = 0; I < NumOperands; ++I) {
     Value *CurrOpperand = GEPI.getOperand(I);
-    GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
-    if (!NewGlobal)
-      continue;
-    IRBuilder<> Builder(&GEPI);
+    NewGlobal = lookupReplacementGlobal(CurrOpperand);
+    if (NewGlobal)
+      break;
+  }
+  if (!NewGlobal)
+    return nullptr;
 
-    SmallVector<Value *, MaxVecSize> Indices;
-    for (auto &Index : GEPI.indices())
-      Indices.push_back(Index);
+  IRBuilder<> Builder(&GEPI);
+  SmallVector<Value *, MaxVecSize> Indices;
+  for (auto &Index : GEPI.indices())
+    Indices.push_back(Index);
 
-    Value *NewGEP =
-        Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+  return Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
+                           GEPI.getName(), GEPI.getNoWrapFlags());
+}
 
-    GEPI.replaceAllUsesWith(NewGEP);
-    GEPI.eraseFromParent();
+bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+  unsigned NumOperands = GEPI.getNumOperands();
+  GlobalVariable *NewGlobal = nullptr;
+  for (unsigned I = 0; I < NumOperands; ++I) {
+    Value *CurrOpperand = GEPI.getOperand(I);
+    NewGlobal = lookupReplacementGlobal(CurrOpperand);
+    if (NewGlobal)
+      break;
   }
+
+  if (!NewGlobal)
+    return false;
+
+  IRBuilder<> Builder(&GEPI);
+  SmallVector<Value *, MaxVecSize> Indices;
+  for (auto &Index : GEPI.indices())
+    Indices.push_back(Index);
+  Value *NewGEP =
+      Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
+                        GEPI.getName(), GEPI.getNoWrapFlags());
+  GEPI.replaceAllUsesWith(NewGEP);
+  GEPI.eraseFromParent();
   return true;
 }
 
@@ -236,16 +283,13 @@ static bool findAndReplaceVectors(Module &M) {
       for (User *U : make_early_inc_range(G.users())) {
         if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
           ConstantExpr *CE = cast<ConstantExpr>(U);
-          convertUsersOfConstantsToInstructions(CE,
-                                                /*RestrictToFunc=*/nullptr,
-                                                /*RemoveDeadConstants=*/false,
-                                                /*IncludeSelf=*/true);
-        }
-        if (Instruction *Inst = dyn_cast<Instruction>(U)) {
-          Function *F = Inst->getFunction();
-          if (F)
-            Impl.visit(*F);
+          for (User *UCE : make_early_inc_range(CE->users())) {
+            if (Instruction *Inst = dyn_cast<Instruction>(UCE))
+              Impl.visit(*Inst);
+          }
         }
+        if (Instruction *Inst = dyn_cast<Instruction>(U))
+          Impl.visit(*Inst);
       }
     }
   }
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index fbf0e4bc35d6c6..2a453ca13c6fc2 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -164,11 +164,18 @@ bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
     Value *CurrOpperand = LI.getOperand(I);
     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      convertUsersOfConstantsToInstructions(CE,
-                                            /*RestrictToFunc=*/nullptr,
-                                            /*RemoveDeadConstants=*/false,
-                                            /*IncludeSelf=*/true);
-      return false;
+      GetElementPtrInst *OldGEP =
+          cast<GetElementPtrInst>(CE->getAsInstruction());
+      OldGEP->insertBefore(&LI);
+
+      IRBuilder<> Builder(&LI);
+      LoadInst *NewLoad =
+          Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+      NewLoad->setAlignment(LI.getAlign());
+      LI.replaceAllUsesWith(NewLoad);
+      LI.eraseFromParent();
+      visitGetElementPtrInst(*OldGEP);
+      return true;
     }
   }
   return false;
@@ -180,11 +187,17 @@ bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
     Value *CurrOpperand = SI.getOperand(I);
     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      convertUsersOfConstantsToInstructions(CE,
-                                            /*RestrictToFunc=*/nullptr,
-                                            /*RemoveDeadConstants=*/false,
-                                            /*IncludeSelf=*/true);
-      return false;
+      GetElementPtrInst *OldGEP =
+          cast<GetElementPtrInst>(CE->getAsInstruction());
+      OldGEP->insertBefore(&SI);
+
+      IRBuilder<> Builder(&SI);
+      StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+      NewStore->setAlignment(SI.getAlign());
+      SI.replaceAllUsesWith(NewStore);
+      SI.eraseFromParent();
+      visitGetElementPtrInst(*OldGEP);
+      return true;
     }
   }
   return false;
@@ -317,20 +330,20 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
 static void collectElements(Constant *Init,
                             SmallVectorImpl<Constant *> &Elements) {
   // Base case: If Init is not an array, add it directly to the vector.
-  if (!isa<ArrayType>(Init->getType())) {
+  auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
+  if (!ArrayTy) {
     Elements.push_back(Init);
     return;
   }
-  auto *LLVMArrayType = dyn_cast<ArrayType>(Init->getType());
+  unsigned ArrSize = ArrayTy->getNumElements();
   if (isa<ConstantAggregateZero>(Init)) {
-    for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
-      Elements.push_back(
-          Constant::getNullValue(LLVMArrayType->getElementType()));
+    for (unsigned I = 0; I < ArrSize; ++I)
+      Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
     return;
   }
   if (isa<UndefValue>(Init)) {
-    for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
-      Elements.push_back(UndefValue::get(LLVMArrayType->getElementType()));
+    for (unsigned I = 0; I < ArrSize; ++I)
+      Elements.push_back(UndefValue::get(ArrayTy->getElementType()));
     return;
   }
 
diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
index 1e0f4fb3d382a1..096868e7d9e0c3 100644
--- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -2,24 +2,22 @@
 ; RUN: opt -S -passes='dxil-flatten-arrays,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 
- at _ZL7Palette = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] undef], align 16
+ at ZerroInitAndUndefArr = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] undef], align 16
 
-; CHECK: @_ZL7Palette.1dim = internal constant [6 x float] [float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float undef, float undef, float undef], align 16
 
-define internal void @_Z4mainDv3_j(<3 x i32> noundef %DID) {
-; CHECK-LABEL: define internal void @_Z4mainDv3_j(
-; CHECK-SAME: <3 x i32> noundef [[DID:%.*]]) {
+define internal void @main() {
+; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitAndUndefArr.1dim, i32 1
 ; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitAndUndefArr.1dim, i32 2
 ; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %0 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 1
+  %0 = getelementptr [8 x [3 x float]], ptr @ZerroInitAndUndefArr, i32 0, i32 1
   %.i0 = load float, ptr %0, align 16
-  %1 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 2
+  %1 = getelementptr [8 x [3 x float]], ptr @ZerroInitAndUndefArr, i32 0, i32 2
   %.i03 = load float, ptr %1, align 16
   ret void
 }
diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
index 59725203836506..4e522c6ef5da7e 100644
--- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
+++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
@@ -21,7 +21,6 @@
 ; CHECK-NOT: @groushared2dArrayofVectors
 ; CHECK-NOT: @groushared2dArrayofVectors.scalarized
 
-
 define <4 x i32> @load_array_vec_test() #0 {
 ; CHECK-LABEL: define <4 x i32> @load_array_vec_test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
@@ -33,18 +32,13 @@ define <4 x i32> @load_array_vec_test() #0 {
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3)
 ; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr [2 x [3 x float]], ptr addrspace(3) [[TMP9]], i32 0, i32 1
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3)
 ; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
-; CHECK-NEXT:    [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
+; CHECK-NEXT:    [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
+; CHECK-NEXT:    [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
+; CHECK-NEXT:    [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
 ; CHECK-NEXT:    [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
 ; CHECK-NEXT:    [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
@@ -87,7 +81,7 @@ define <4 x i32> @load_vec_test() #0 {
 define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
 ; CHECK-LABEL: define <4 x i32> @load_static_array_of_vec_test(
 ; CHECK-SAME: i32 [[INDEX:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
+; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
 ; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
@@ -121,18 +115,13 @@ define <4 x i32> @multid_load_test() #0 {
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3) to ptr addrspace(3)
 ; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr [3 x [3 x [4 x i32]]], ptr addrspace(3) [[TMP9]], i32 0, i32 1, i32 1
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1) to ptr addrspace(3)
 ; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
-; CHECK-NEXT:    [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
+; CHECK-NEXT:    [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
+; CHECK-NEXT:    [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
-; CHECK-NEXT:    [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
+; CHECK-NEXT:    [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
 ; CHECK-NEXT:    [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
 ; CHECK-NEXT:    [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
 ; CHECK-NEXT:    [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index aa09fa401a6492..25dc2c36b4e1f3 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -2,31 +2,24 @@
 ; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 
- at _ZL7Palette = internal constant [8 x <3 x float>] [<3 x float> zeroinitializer, <3 x float> splat (float 5.000000e-01), <3 x float> <float 1.000000e+00, float 5.000000e-01, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 1.000000e+00, float 5.000000e-01>], align 16
+ at StaticArr = internal constant [8 x <3 x float>] [<3 x float> zeroinitializer, <3 x float> splat (float 5.000000e-01), <3 x float> <float 1.000000e+00, float 5.000000e-01, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 1.000000e+00, float 5.000000e-01>], align 16
 
 ; Function Attrs: alwaysinline convergent mustprogress norecurse nounwind
-define internal void @_Z4mainDv3_j(<3 x i32> noundef %DID) #1 {
-; CHECK-LABEL: define internal void @_Z4mainDv3_j(
-; CHECK-SAME: <3 x i32> noundef [[DID:%.*]]) {
+define internal void @main() #1 {
+; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.scalarized.1dim, i32 1
-; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
-; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
-; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
-; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.scalarized.1dim, i32 2
-; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
-; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
-; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
-; CHECK-NEXT:    [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
-; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
+; CHECK-NEXT:    [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
+; CHECK-NEXT:    [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
+; CHECK-NEXT:    [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
+; CHECK-NEXT:    [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
+; CHECK-NEXT:    [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %arrayidx = getelementptr inbounds [8 x <3 x float>], ptr @_ZL7Palette, i32 0, i32 1
+  %arrayidx = getelementptr inbounds [8 x <3 x float>], ptr @StaticArr, i32 0, i32 1
   %2 = load <3 x float>, ptr %arrayidx, align 16
-  %arrayidx2 = getelementptr inbounds [8 x <3 x float>], ptr @_ZL7Palette, i32 0, i32 2
+  %arrayidx2 = getelementptr inbounds [8 x <3 x float>], ptr @StaticArr, i32 0, i32 2
   %3 = load <3 x float>, ptr %arrayidx2, align 16
   ret void
 }
diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll
index a32db8b8e39951..ed1e9109b7b182 100644
--- a/llvm/test/CodeGen/DirectX/scalar-load.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-load.ll
@@ -2,10 +2,10 @@
 
 ; Make sure we can load groupshared, static vectors and arrays of vectors
 
-@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
-@"vecData" = external addrspace(3) global <4 x i32>, align 4
+ at arrayofVecData = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+ at vecData = external addrspace(3) global <4 x i32>, align 4
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
+ at groushared2dArrayofVectors = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
 
 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
@@ -19,12 +19,12 @@
 
 
 ; CHECK-LABEL: load_array_vec_test
-define <4 x i32> @load_array_vec_test() #0 {
-  ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
-  %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
-  %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
-  %3 = add <4 x i32> %1, %2
-  ret <4 x i32> %3
+define <3 x float> @load_array_vec_test() #0 {
+  ; CHECK-COUNT-6: load float, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4
+  %1 = load <3 x float>, <3 x float> addrspace(3)* getelementptr inbounds ([2 x <3 x float>], [2 x <3 x float>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
+  %2 = load <3 x float>, <3 x float> addrspace(3)* getelementptr inbounds ([2 x <3 x float>], [2 x <3 x float>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4
+  %3 = fadd <3 x float> %1, %2
+  ret <3 x float> %3
 }
 
 ; CHECK-LABEL: load_vec_test
@@ -36,8 +36,14 @@ define <4 x i32> @load_vec_test() #0 {
 
 ; CHECK-LABEL: load_static_array_of_vec_test
 define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
-  ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
-  ; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4
+  ; CHECK: getelementptr inbounds [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index
+  ; CHECK: load i32, ptr {{.*}}, align 4
+  ; CHECK: getelementptr i32, ptr {{.*}}, i32 1
+  ; CHECK: load i32, ptr {{.*}}, align 4
+  ; CHECK: getelementptr i32, ptr {{.*}}, i32 2
+  ; CHECK: load i32, ptr {{.*}}, align 4
+  ; CHECK: getelementptr i32, ptr {{.*}}, i32 3
+  ; CHECK: load i32, ptr {{.*}}, align 4
   %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
   %4 = load <4 x i32>, <4 x i32>* %3, align 4
   ret <4 x i32> %4



More information about the llvm-commits mailing list