[llvm] [DirectX] Don't limit visitGetElementPtrInst to global ptrs (PR #144959)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 14:33:45 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

fixes #<!-- -->144608
- There is a `getPointerOperandIndex` function, so we don't need to iterate over the operands to find the pointer. This allowed a small cleanup of `visitStoreInst` and `visitLoadInst`.

- The meat of this change is in `visitGetElementPtrInst`, which now also handles GEPs whose pointer operand is an alloca, instead of bailing whenever no replacement global is found.

---
Full diff: https://github.com/llvm/llvm-project/pull/144959.diff


2 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (+55-49) 
- (modified) llvm/test/CodeGen/DirectX/scalarize-alloca.ll (+17-2) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 06708cec00cec..61c5301ed5051 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -14,11 +14,13 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/ReplaceConstant.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 
@@ -127,71 +129,75 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
 }
 
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
-  unsigned NumOperands = LI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = LI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(LI.getIterator());
-      IRBuilder<> Builder(&LI);
-      LoadInst *NewLoad =
-          Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
-      NewLoad->setAlignment(LI.getAlign());
-      LI.replaceAllUsesWith(NewLoad);
-      LI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      LI.setOperand(I, NewGlobal);
+  Value *PtrOperand = LI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(LI.getIterator());
+    IRBuilder<> Builder(&LI);
+    LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+    NewLoad->setAlignment(LI.getAlign());
+    LI.replaceAllUsesWith(NewLoad);
+    LI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
   return false;
 }
 
 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
-  unsigned NumOperands = SI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = SI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(SI.getIterator());
-      IRBuilder<> Builder(&SI);
-      StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
-      NewStore->setAlignment(SI.getAlign());
-      SI.replaceAllUsesWith(NewStore);
-      SI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      SI.setOperand(I, NewGlobal);
+
+  Value *PtrOperand = SI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(SI.getIterator());
+    IRBuilder<> Builder(&SI);
+    StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+    NewStore->setAlignment(SI.getAlign());
+    SI.replaceAllUsesWith(NewStore);
+    SI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
+
   return false;
 }
 
 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-
-  unsigned NumOperands = GEPI.getNumOperands();
-  GlobalVariable *NewGlobal = nullptr;
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = GEPI.getOperand(I);
-    NewGlobal = lookupReplacementGlobal(CurrOpperand);
-    if (NewGlobal)
-      break;
+  Value *PtrOperand = GEPI.getPointerOperand();
+  Type *OrigGEPType = GEPI.getPointerOperandType();
+  Type *NewGEPType = OrigGEPType;
+  bool NeedsTransform = false;
+
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
+    NewGEPType = NewGlobal->getValueType();
+    PtrOperand = NewGlobal;
+    NeedsTransform = true;
+  } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+    Type *AllocatedType = Alloca->getAllocatedType();
+    // OrigGEPType might just be a pointer lets make sure
+    // to add the allocated type so we have a size
+    if (AllocatedType != OrigGEPType) {
+      NewGEPType = AllocatedType;
+      NeedsTransform = true;
+    }
   }
-  if (!NewGlobal)
+
+  // Note: We bail if this isn't a gep touched via alloca or global
+  // transformations
+  if (!NeedsTransform)
     return false;
 
   IRBuilder<> Builder(&GEPI);
   SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
 
-  Value *NewGEP =
-      Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
-                        GEPI.getName(), GEPI.getNoWrapFlags());
+  Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
+                                    GEPI.getName(), GEPI.getNoWrapFlags());
   GEPI.replaceAllUsesWith(NewGEP);
   GEPI.eraseFromParent();
   return true;
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index 4829f3a31791f..b589136d6965c 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -1,10 +1,25 @@
-; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
-; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
 
 ; CHECK-LABEL: alloca_2d__vec_test
 define void @alloca_2d__vec_test() local_unnamed_addr #2 {
   ; SCHECK:  alloca [2 x [4 x i32]], align 16
   ; FCHECK:  alloca [8 x i32], align 16
+  ; CHECK: ret void
   %1 = alloca [2 x <4 x i32>], align 16
   ret void
 }
+
+; CHECK-LABEL: alloca_2d_gep_test
+define void @alloca_2d_gep_test() {
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [2 x [2 x i32]], align 16
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [4 x i32], align 16
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; CHECK: ret void
+  %1 = alloca [2 x <2 x i32>], align 16
+  %2 = tail call i32 @llvm.dx.thread.id(i32 0)
+  %3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/144959


More information about the llvm-commits mailing list