[llvm] 75c09b7 - [DirectX] Let data scalarizer pass account for sub-types when updating GEP type (#166200)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 09:14:55 PST 2025


Author: Finn Plummer
Date: 2025-11-06T09:14:50-08:00
New Revision: 75c09b792433fffc442e0ea53b45ee8e330f8acc

URL: https://github.com/llvm/llvm-project/commit/75c09b792433fffc442e0ea53b45ee8e330f8acc
DIFF: https://github.com/llvm/llvm-project/commit/75c09b792433fffc442e0ea53b45ee8e330f8acc.diff

LOG: [DirectX] Let data scalarizer pass account for sub-types when updating GEP type (#166200)

This PR lets the `dxil-data-scalarization` pass account for a GEP whose
source type is a sub-type of the type of the alloca or global variable
that its pointer operand refers to.

The pass is updated so that the replacement GEP prepends zero indices for
the missing outer dimensions, keeping the offset calculation and the
result type the same (modulo the vector -> array transform).

Please see the resolved issue for an annotated example.

Resolves: https://github.com/llvm/llvm-project/issues/165473

Added: 
    llvm/test/CodeGen/DirectX/scalarize-global.ll

Modified: 
    llvm/lib/Target/DirectX/DXILDataScalarization.cpp
    llvm/test/CodeGen/DirectX/scalarize-alloca.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index d507d71b99fc9..9f1616f6960fe 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -304,40 +304,76 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
   GEPOperator *GOp = cast<GEPOperator>(&GEPI);
   Value *PtrOperand = GOp->getPointerOperand();
   Type *NewGEPType = GOp->getSourceElementType();
-  bool NeedsTransform = false;
 
   // Unwrap GEP ConstantExprs to find the base operand and element type
-  while (auto *CE = dyn_cast<ConstantExpr>(PtrOperand)) {
-    if (auto *GEPCE = dyn_cast<GEPOperator>(CE)) {
-      GOp = GEPCE;
-      PtrOperand = GEPCE->getPointerOperand();
-      NewGEPType = GEPCE->getSourceElementType();
-    } else
-      break;
+  while (auto *GEPCE = dyn_cast_or_null<GEPOperator>(
+             dyn_cast<ConstantExpr>(PtrOperand))) {
+    GOp = GEPCE;
+    PtrOperand = GEPCE->getPointerOperand();
+    NewGEPType = GEPCE->getSourceElementType();
   }
 
+  Type *const OrigGEPType = NewGEPType;
+  Value *const OrigOperand = PtrOperand;
+
   if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
     NewGEPType = NewGlobal->getValueType();
     PtrOperand = NewGlobal;
-    NeedsTransform = true;
   } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
     Type *AllocatedType = Alloca->getAllocatedType();
     if (isa<ArrayType>(AllocatedType) &&
-        AllocatedType != GOp->getResultElementType()) {
+        AllocatedType != GOp->getResultElementType())
       NewGEPType = AllocatedType;
-      NeedsTransform = true;
+  } else
+    return false; // Only GEPs into an alloca or global variable are considered
+
+  // Defer changing i8 GEP types until dxil-flatten-arrays
+  if (OrigGEPType->isIntegerTy(8))
+    NewGEPType = OrigGEPType;
+
+  // If the original type is a "sub-type" of the new type, then ensure the gep
+  // correctly zero-indexes the extra dimensions to keep the offset calculation
+  // correct.
+  // Eg:
+  //  i32, [4 x i32] and [8 x [4 x i32]] are sub-types of [8 x [4 x i32]], etc.
+  //
+  // So then:
+  //   gep [4 x i32] %idx
+  //     -> gep [8 x [4 x i32]], i32 0, i32 %idx
+  //   gep i32 %idx
+  //     -> gep [8 x [4 x i32]], i32 0, i32 0, i32 %idx
+  uint32_t MissingDims = 0;
+  Type *SubType = NewGEPType;
+
+  // The new type will be in its array version; so match accordingly.
+  Type *const GEPArrType = equivalentArrayTypeFromVector(OrigGEPType);
+
+  while (SubType != GEPArrType) {
+    MissingDims++;
+
+    ArrayType *ArrType = dyn_cast<ArrayType>(SubType);
+    if (!ArrType) {
+      assert(SubType == GEPArrType &&
+             "GEP uses an DXIL invalid sub-type of alloca/global variable");
+      break;
     }
+
+    SubType = ArrType->getElementType();
   }
 
+  bool NeedsTransform = OrigOperand != PtrOperand ||
+                        OrigGEPType != NewGEPType || MissingDims != 0;
+
   if (!NeedsTransform)
     return false;
 
-  // Keep scalar GEPs scalar; dxil-flatten-arrays will do flattening later
-  if (!isa<ArrayType>(GOp->getSourceElementType()))
-    NewGEPType = GOp->getSourceElementType();
-
   IRBuilder<> Builder(&GEPI);
-  SmallVector<Value *, MaxVecSize> Indices(GOp->indices());
+  SmallVector<Value *, MaxVecSize> Indices;
+
+  for (uint32_t I = 0; I < MissingDims; I++)
+    Indices.push_back(Builder.getInt32(0));
+  llvm::append_range(Indices, GOp->indices());
+
   Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
                                     GOp->getName(), GOp->getNoWrapFlags());
 

diff  --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index a8557e47b0ea6..475935d2eb135 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -42,3 +42,68 @@ define void @alloca_2d_gep_test() {
   %3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
   ret void
 }
+
+; CHECK-LABEL: subtype_array_test
+define void @subtype_array_test() { ; alloca indexed with an array sub-type of its allocated type
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [32 x i32], align 4
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+  ; CHECK: ret void
+  %arr = alloca [8 x [4 x i32]], align 4
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw [4 x i32], ptr %arr, i32 %i ; [4 x i32] is one level below the alloca type: one leading zero index expected
+  ret void
+}
+
+; CHECK-LABEL: subtype_vector_test
+define void @subtype_vector_test() { ; alloca of vectors indexed with the vector element type
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [32 x i32], align 4
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+  ; CHECK: ret void
+  %arr = alloca [8 x <4 x i32>], align 4
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw <4 x i32>, ptr %arr, i32 %i ; <4 x i32> maps to [4 x i32], one level below the scalarized alloca type: one leading zero index expected
+  ret void
+}
+
+; CHECK-LABEL: subtype_scalar_test
+define void @subtype_scalar_test() { ; alloca indexed with its innermost scalar type
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [32 x i32], align 4
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+  ; CHECK: ret void
+  %arr = alloca [8 x [4 x i32]], align 4
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw i32, ptr %arr, i32 %i ; i32 is two levels below [8 x [4 x i32]]: two leading zero indices expected
+  ret void
+}
+
+; CHECK-LABEL: subtype_i8_test
+define void @subtype_i8_test() { ; byte-offset GEP into a 2-D array alloca
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [32 x i32], align 4
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr [[alloca_val]], i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+  ; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+  ; CHECK: ret void
+  %arr = alloca [8 x [4 x i32]], align 4
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw i8, ptr %arr, i32 %i ; i8 GEPs are left unchanged by scalarization; flattening turns the byte offset into an i32 index (lshr 2)
+  ret void
+}

diff  --git a/llvm/test/CodeGen/DirectX/scalarize-global.ll b/llvm/test/CodeGen/DirectX/scalarize-global.ll
new file mode 100644
index 0000000000000..ca10f6ece5a85
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalarize-global.ll
@@ -0,0 +1,70 @@
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
+
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [8 x <4 x i32>] zeroinitializer, align 16 ; scalarized to [8 x [4 x i32]], then flattened to [32 x i32]
+@"vecData" = external addrspace(3) global <4 x i32>, align 4 ; not used by the functions below; only its rewrite to [4 x i32] is verified
+
+; SCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [8 x [4 x i32]] zeroinitializer, align 16
+; FCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [32 x i32] zeroinitializer, align 16
+; CHECK: [[vecData:@vecData.*]] = external addrspace(3) global [4 x i32], align 4
+
+; CHECK-LABEL: subtype_array_test
+define <4 x i32> @subtype_array_test() { ; global indexed with an array sub-type of its value type
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+  ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+  ; CHECK: ret <4 x i32> [[x]]
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw [4 x i32], ptr addrspace(3) @"arrayofVecData", i32 %i ; one level below the scalarized global type: one leading zero index expected
+  %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+  ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_vector_test
+define <4 x i32> @subtype_vector_test() { ; global of vectors indexed with the vector element type
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+  ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+  ; CHECK: ret <4 x i32> [[x]]
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw <4 x i32>, ptr addrspace(3) @"arrayofVecData", i32 %i ; <4 x i32> maps to [4 x i32]: one leading zero index expected
+  %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+  ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_scalar_test
+define <4 x i32> @subtype_scalar_test() { ; global indexed with its innermost scalar type
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 0, i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+  ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+  ; CHECK: ret <4 x i32> [[x]]
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw i32, ptr addrspace(3) @"arrayofVecData", i32 %i ; i32 is two levels below [8 x [4 x i32]]: two leading zero indices expected
+  %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+  ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_i8_test
+define <4 x i32> @subtype_i8_test() { ; byte-offset GEP into a global array of vectors
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr addrspace(3) [[arrayofVecData]], i32 [[tid]]
+  ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+  ; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
+  ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+  ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+  ; CHECK: ret <4 x i32> [[x]]
+  %i = tail call i32 @llvm.dx.thread.id(i32 0)
+  %gep = getelementptr inbounds nuw i8, ptr addrspace(3) @"arrayofVecData", i32 %i ; i8 GEPs are left unchanged by scalarization; flattening turns the byte offset into an i32 index (lshr 2)
+  %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+  ret <4 x i32> %x
+}


        


More information about the llvm-commits mailing list