[llvm] [DirectX] Let data scalarizer pass account for sub-types when updating GEP type (PR #166200)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 09:19:25 PST 2025
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/166200
>From ab1327f9a7e779ed1dfee9d23b8e1a3e5626098c Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Fri, 31 Oct 2025 10:21:00 -0700
Subject: [PATCH 1/3] drive by clean up
---
llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index d507d71b99fc9..9fa3591159e7b 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -307,13 +307,11 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
bool NeedsTransform = false;
// Unwrap GEP ConstantExprs to find the base operand and element type
- while (auto *CE = dyn_cast<ConstantExpr>(PtrOperand)) {
- if (auto *GEPCE = dyn_cast<GEPOperator>(CE)) {
- GOp = GEPCE;
- PtrOperand = GEPCE->getPointerOperand();
- NewGEPType = GEPCE->getSourceElementType();
- } else
- break;
+ while (auto *GEPCE = dyn_cast_or_null<GEPOperator>(
+ dyn_cast<ConstantExpr>(PtrOperand))) {
+ GOp = GEPCE;
+ PtrOperand = GEPCE->getPointerOperand();
+ NewGEPType = GEPCE->getSourceElementType();
}
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
>From 993b51a6eed82aba326ed350f8e2c67d22c4ff4d Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Fri, 31 Oct 2025 14:56:12 -0700
Subject: [PATCH 2/3] [DirectX] Make data scalarization pass account for GEP as
a sub-type
---
.../Target/DirectX/DXILDataScalarization.cpp | 56 ++++++++++++---
llvm/test/CodeGen/DirectX/scalarize-alloca.ll | 65 +++++++++++++++++
llvm/test/CodeGen/DirectX/scalarize-global.ll | 70 +++++++++++++++++++
3 files changed, 182 insertions(+), 9 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/scalarize-global.ll
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 9fa3591159e7b..88a9b5084cba2 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -304,7 +304,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
GEPOperator *GOp = cast<GEPOperator>(&GEPI);
Value *PtrOperand = GOp->getPointerOperand();
Type *NewGEPType = GOp->getSourceElementType();
- bool NeedsTransform = false;
// Unwrap GEP ConstantExprs to find the base operand and element type
while (auto *GEPCE = dyn_cast_or_null<GEPOperator>(
@@ -314,28 +313,67 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
NewGEPType = GEPCE->getSourceElementType();
}
+ Type *const OrigGEPType = NewGEPType;
+ Value *const OrigOperand = PtrOperand;
+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
NewGEPType = NewGlobal->getValueType();
PtrOperand = NewGlobal;
- NeedsTransform = true;
} else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
Type *AllocatedType = Alloca->getAllocatedType();
if (isa<ArrayType>(AllocatedType) &&
- AllocatedType != GOp->getResultElementType()) {
+ AllocatedType != GOp->getResultElementType())
NewGEPType = AllocatedType;
- NeedsTransform = true;
+ } else
+ return false; // Only GEPs into an alloca or global variable are considered
+
+ // Defer changing i8 GEP types until dxil-flatten-arrays
+ if (OrigGEPType->isIntegerTy(8))
+ NewGEPType = OrigGEPType;
+
+ // If the original type is a "sub-type" of the new type, then ensure the gep
+ // correctly zero-indexes the extra dimensions to keep the offset calculation
+ // correct.
+ // Eg:
+ // i32, [4 x i32] and [8 x [4 x i32]] are sub-types of [8 x [4 x i32]], etc.
+ //
+ // So then:
+ // gep [4 x i32] %idx
+ // -> gep [8 x [4 x i32]], i32 0, i32 %idx
+ // gep i32 %idx
+ // -> gep [8 x [4 x i32]], i32 0, i32 0, i32 %idx
+ uint32_t MissingDims = 0;
+ Type *SubType = NewGEPType;
+
+ // The new type will be in its array version so match accordingly
+ Type *const GEPArrType = equivalentArrayTypeFromVector(OrigGEPType);
+
+ while (SubType != GEPArrType) {
+ MissingDims++;
+
+ ArrayType *ArrType = dyn_cast<ArrayType>(SubType);
+ if (!ArrType) {
+ assert(SubType == GEPArrType && "GEP uses a strange sub-type of alloca/global variable");
+ break;
}
+
+ SubType = ArrType->getElementType();
}
+
+ bool NeedsTransform = OrigOperand != PtrOperand ||
+ OrigGEPType != NewGEPType || MissingDims != 0;
+
if (!NeedsTransform)
return false;
- // Keep scalar GEPs scalar; dxil-flatten-arrays will do flattening later
- if (!isa<ArrayType>(GOp->getSourceElementType()))
- NewGEPType = GOp->getSourceElementType();
-
IRBuilder<> Builder(&GEPI);
- SmallVector<Value *, MaxVecSize> Indices(GOp->indices());
+ SmallVector<Value *, MaxVecSize> Indices;
+
+ for (uint32_t I = 0; I < MissingDims; I++)
+ Indices.push_back(Builder.getInt32(0));
+ llvm::append_range(Indices, GOp->indices());
+
Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
GOp->getName(), GOp->getNoWrapFlags());
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index a8557e47b0ea6..475935d2eb135 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -42,3 +42,68 @@ define void @alloca_2d_gep_test() {
%3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
ret void
}
+
+; CHECK-LABEL: subtype_array_test
+define void @subtype_array_test() {
+ ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+ ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+ ; CHECK: ret void
+ %arr = alloca [8 x [4 x i32]], align 4
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw [4 x i32], ptr %arr, i32 %i
+ ret void
+}
+
+; CHECK-LABEL: subtype_vector_test
+define void @subtype_vector_test() {
+ ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+ ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+ ; CHECK: ret void
+ %arr = alloca [8 x <4 x i32>], align 4
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw <4 x i32>, ptr %arr, i32 %i
+ ret void
+}
+
+; CHECK-LABEL: subtype_scalar_test
+define void @subtype_scalar_test() {
+ ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+ ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+ ; CHECK: ret void
+ %arr = alloca [8 x [4 x i32]], align 4
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw i32, ptr %arr, i32 %i
+ ret void
+}
+
+; CHECK-LABEL: subtype_i8_test
+define void @subtype_i8_test() {
+ ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
+ ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr [[alloca_val]], i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+ ; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
+ ; CHECK: ret void
+ %arr = alloca [8 x [4 x i32]], align 4
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw i8, ptr %arr, i32 %i
+ ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/scalarize-global.ll b/llvm/test/CodeGen/DirectX/scalarize-global.ll
new file mode 100644
index 0000000000000..ca10f6ece5a85
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalarize-global.ll
@@ -0,0 +1,70 @@
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
+
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [8 x <4 x i32>] zeroinitializer, align 16
+@"vecData" = external addrspace(3) global <4 x i32>, align 4
+
+; SCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [8 x [4 x i32]] zeroinitializer, align 16
+; FCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [32 x i32] zeroinitializer, align 16
+; CHECK: [[vecData:@vecData.*]] = external addrspace(3) global [4 x i32], align 4
+
+; CHECK-LABEL: subtype_array_test
+define <4 x i32> @subtype_array_test() {
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+ ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+ ; CHECK: ret <4 x i32> [[x]]
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw [4 x i32], ptr addrspace(3) @"arrayofVecData", i32 %i
+ %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+ ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_vector_test
+define <4 x i32> @subtype_vector_test() {
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+ ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+ ; CHECK: ret <4 x i32> [[x]]
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw <4 x i32>, ptr addrspace(3) @"arrayofVecData", i32 %i
+ %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+ ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_scalar_test
+define <4 x i32> @subtype_scalar_test() {
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 0, i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+ ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+ ; CHECK: ret <4 x i32> [[x]]
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw i32, ptr addrspace(3) @"arrayofVecData", i32 %i
+ %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+ ret <4 x i32> %x
+}
+
+; CHECK-LABEL: subtype_i8_test
+define <4 x i32> @subtype_i8_test() {
+ ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+ ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr addrspace(3) [[arrayofVecData]], i32 [[tid]]
+ ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
+ ; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
+ ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
+ ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
+ ; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
+ ; CHECK: ret <4 x i32> [[x]]
+ %i = tail call i32 @llvm.dx.thread.id(i32 0)
+ %gep = getelementptr inbounds nuw i8, ptr addrspace(3) @"arrayofVecData", i32 %i
+ %x = load <4 x i32>, ptr addrspace(3) %gep, align 4
+ ret <4 x i32> %x
+}
>From c11205bd75d8331169d7e0645819c0e1794b5df1 Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Mon, 3 Nov 2025 09:39:30 -0800
Subject: [PATCH 3/3] touch up
---
llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 88a9b5084cba2..1bfbee928bbf1 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -353,14 +353,14 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
ArrayType *ArrType = dyn_cast<ArrayType>(SubType);
if (!ArrType) {
- assert(SubType == GEPArrType && "GEP uses a strange sub-type of alloca/global variable");
+ assert(SubType == GEPArrType &&
+ "GEP uses an DXIL invalid sub-type of alloca/global variable");
break;
}
SubType = ArrType->getElementType();
}
-
bool NeedsTransform = OrigOperand != PtrOperand ||
OrigGEPType != NewGEPType || MissingDims != 0;
More information about the llvm-commits
mailing list