[llvm] 13a6342 - [DirectX] Fix the writing of ConstantExpr GEPs to DXIL bitcode (#154446)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 26 10:02:02 PDT 2025


Author: Deric C.
Date: 2025-08-26T10:01:59-07:00
New Revision: 13a634281fb080ab0641998eb336ee591631beff

URL: https://github.com/llvm/llvm-project/commit/13a634281fb080ab0641998eb336ee591631beff
DIFF: https://github.com/llvm/llvm-project/commit/13a634281fb080ab0641998eb336ee591631beff.diff

LOG: [DirectX] Fix the writing of ConstantExpr GEPs to DXIL bitcode (#154446)

Fixes #153304

Changes:
- When writing `ConstantExpr` GEPs to DXIL bitcode, the bitcode writer
now uses the old constant code `CST_CODE_CE_GEP_OLD = 12` instead of the
newer `CST_CODE_CE_GEP = 32`, which DXIL interprets as an undef.
Additional context: DXC defines [CST_CODE_CE_GEP = 12 in
DXC](https://github.com/microsoft/DirectXShaderCompiler/blob/0c9e75e7e91bb18fab101abc81d399a0296f499e/include/llvm/Bitcode/LLVMBitCodes.h#L187),
whereas LLVM names the same constant code [CST_CODE_CE_GEP_OLD in
LLVM](https://github.com/llvm/llvm-project/blob/65de318d186c815f43b892aa20b98c50f22ab6fe/llvm/include/llvm/Bitcode/LLVMBitCodes.h#L411).
- Modifies `PointerTypeAnalysis` to also analyze pointer-typed
constants that appear as instruction operands, so that the correct
pointee type of each `ConstantExpr` GEP is determined and written into
the DXIL bitcode.
- Adds a `PointerTypeAnalysis` unit test and a dxil-dis test to ensure
that the pointer types of `ConstantExpr` GEPs are resolved and that
`ConstantExpr` GEPs are written to DXIL bitcode correctly.

In addition, this PR adds a missing call to
`GV.removeDeadConstantUsers()` in the DXILFinalizeLinkage pass, and
removes an unnecessary manual destruction of a ConstantExpr in the
DXILDataScalarization pass.

Added: 
    llvm/test/tools/dxil-dis/constantexpr-gep.ll

Modified: 
    llvm/lib/Target/DirectX/DXILDataScalarization.cpp
    llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
    llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
    llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
    llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index feecfc0880e25..d507d71b99fc9 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -343,9 +343,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
 
   GOp->replaceAllUsesWith(NewGEP);
 
-  if (auto *CE = dyn_cast<ConstantExpr>(GOp))
-    CE->destroyConstant();
-  else if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
+  if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
     OldGEPI->eraseFromParent();
 
   return true;

diff  --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
index 13e3408815bba..aa16e795dc768 100644
--- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
+++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
@@ -22,11 +22,13 @@ static bool finalizeLinkage(Module &M) {
 
   // Convert private globals and external globals with no usage to internal
   // linkage.
-  for (GlobalVariable &GV : M.globals())
+  for (GlobalVariable &GV : M.globals()) {
+    GV.removeDeadConstantUsers();
     if (GV.hasPrivateLinkage() || (GV.hasExternalLinkage() && GV.use_empty())) {
       GV.setLinkage(GlobalValue::InternalLinkage);
       MadeChange = true;
     }
+  }
 
   SmallVector<Function *> Funcs;
 

diff  --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
index 1d79c3018439e..bc1a3a7995bda 100644
--- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
+++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
@@ -2113,7 +2113,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
         }
         break;
       case Instruction::GetElementPtr: {
-        Code = bitc::CST_CODE_CE_GEP;
+        Code = bitc::CST_CODE_CE_GEP_OLD;
         const auto *GO = cast<GEPOperator>(C);
         if (GO->isInBounds())
           Code = bitc::CST_CODE_CE_INBOUNDS_GEP;

diff  --git a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
index f99bb4f4eaee1..c2e139edc6bd1 100644
--- a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
+++ b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
@@ -15,25 +15,39 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
 namespace {
 
+Type *classifyFunctionType(const Function &F, PointerTypeMap &Map);
+
 // Classifies the type of the value passed in by walking the value's users to
 // find a typed instruction to materialize a type from.
 Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
   assert(V->getType()->isPointerTy() &&
          "classifyPointerType called with non-pointer");
+
+  // A CallInst will trigger this case, and we want to classify its Function
+  // operand as a Function rather than a generic Value.
+  if (const Function *F = dyn_cast<Function>(V))
+    return classifyFunctionType(*F, Map);
+
+  // There can potentially be dead constants hanging off of the globals we do
+  // not want to deal with. So we remove them here.
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+    GV->removeDeadConstantUsers();
+
   auto It = Map.find(V);
   if (It != Map.end())
     return It->second;
 
   Type *PointeeTy = nullptr;
-  if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
-    if (!Inst->getResultElementType()->isPointerTy())
-      PointeeTy = Inst->getResultElementType();
+  if (auto *GEP = dyn_cast<GEPOperator>(V)) {
+    if (!GEP->getResultElementType()->isPointerTy())
+      PointeeTy = GEP->getResultElementType();
   } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
     PointeeTy = Inst->getAllocatedType();
   } else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
@@ -49,8 +63,8 @@ Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
       // When store value is ptr type, cannot get more type info.
       if (NewPointeeTy->isPointerTy())
         continue;
-    } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
-      NewPointeeTy = Inst->getSourceElementType();
+    } else if (const auto *GEP = dyn_cast<GEPOperator>(User)) {
+      NewPointeeTy = GEP->getSourceElementType();
     }
     if (NewPointeeTy) {
       // HLSL doesn't support pointers, so it is unlikely to get more than one
@@ -204,6 +218,9 @@ PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
       for (const auto &I : B) {
         if (I.getType()->isPointerTy())
           classifyPointerType(&I, Map);
+        for (const auto &O : I.operands())
+          if (O.get()->getType()->isPointerTy())
+            classifyPointerType(O.get(), Map);
       }
     }
   }

diff  --git a/llvm/test/tools/dxil-dis/constantexpr-gep.ll b/llvm/test/tools/dxil-dis/constantexpr-gep.ll
new file mode 100644
index 0000000000000..59251474f1a4b
--- /dev/null
+++ b/llvm/test/tools/dxil-dis/constantexpr-gep.ll
@@ -0,0 +1,35 @@
+; RUN: llc --filetype=obj %s -o - | dxil-dis -o - | FileCheck %s
+target triple = "dxil-unknown-shadermodel6.7-library"
+
+; CHECK: [[GLOBAL:@.*]] = unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4
+@g = local_unnamed_addr addrspace(3) global [10 x i32] zeroinitializer, align 4
+
+define i32 @fn() #0 {
+; CHECK-LABEL:  define i32 @fn()
+; CHECK-NEXT:   [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 1), align 4
+; CHECK-NEXT:   ret i32 [[LOAD]]
+;
+  %gep = getelementptr [10 x i32], ptr addrspace(3) @g, i32 0, i32 1
+  %ld = load i32, ptr addrspace(3) %gep, align 4
+  ret i32 %ld
+}
+
+define i32 @fn2() #0 {
+; CHECK-LABEL:  define i32 @fn2()
+; CHECK-NEXT:   [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 2), align 4
+; CHECK-NEXT:   ret i32 [[LOAD]]
+;
+  %ld = load i32, ptr addrspace(3) getelementptr ([10 x i32], ptr addrspace(3) @g, i32 0, i32 2), align 4
+  ret i32 %ld
+}
+
+define i32 @fn3() #0 {
+; CHECK-LABEL:  define i32 @fn3()
+; CHECK-NEXT:   [[LOAD:%.*]] = load i32, i32 addrspace(3)* getelementptr inbounds ([10 x i32], [10 x i32] addrspace(3)* [[GLOBAL]], i32 0, i32 3), align 4
+; CHECK-NEXT:   ret i32 [[LOAD]]
+;
+  %ld = load i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @g, i32 12), align 4
+  ret i32 %ld
+}
+
+attributes #0 = { "hlsl.export" }

diff  --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
index 9d41e94bb0bae..6ae139e076281 100644
--- a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
+++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
@@ -8,6 +8,7 @@
 
 #include "DirectXIRPasses/PointerTypeAnalysis.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -123,6 +124,33 @@ TEST(PointerTypeAnalysis, DiscoverGEP) {
   EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
 }
 
+TEST(PointerTypeAnalysis, DiscoverConstantExprGEP) {
+  StringRef Assembly = R"(
+    @g = internal global [10 x i32] zeroinitializer
+    define i32 @test() {
+      %i = load i32, ptr getelementptr ([10 x i32], ptr @g, i64 0, i64 1)
+      ret i32 %i
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 3u);
+  Type *I32Ty = Type::getInt32Ty(Context);
+  Type *I32Ptr = TypedPointerType::get(I32Ty, 0);
+  Type *I32ArrPtr = TypedPointerType::get(ArrayType::get(I32Ty, 10), 0);
+  Type *FnTy = FunctionType::get(I32Ty, {}, false);
+
+  EXPECT_THAT(Map, Contains(Pair(IsA<GlobalVariable>(), I32ArrPtr)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
+  EXPECT_THAT(Map, Contains(Pair(IsA<ConstantExpr>(), I32Ptr)));
+}
+
 TEST(PointerTypeAnalysis, TraceIndirect) {
   StringRef Assembly = R"(
     define i64 @test(ptr %p) {


        


More information about the llvm-commits mailing list