[llvm] [MergeFunc] Fix crash caused by bitcasting ArrayType (PR #133259)

Tobias Stadler via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 27 07:42:50 PDT 2025


https://github.com/tobias-stadler created https://github.com/llvm/llvm-project/pull/133259

createCast in MergeFunctions did not handle ArrayTypes, so it emitted a bitcast between ArrayTypes in the thunk function. A bitcast is not valid on aggregate types, which led to an assertion failure in the provided test case.

The version of createCast in GlobalMergeFunctions does handle ArrayTypes, so this common code has been factored out into the IRBuilder.

From 51d43552b2678739990f6fa4c7dbce09b7426254 Mon Sep 17 00:00:00 2001
From: Tobias Stadler <mail at stadler-tobias.de>
Date: Wed, 26 Mar 2025 16:32:44 +0000
Subject: [PATCH] [MergeFunc] Fix crash caused by bitcasting ArrayType

createCast in MergeFunctions did not consider ArrayTypes, which results
in the creation of a bitcast between ArrayTypes in the thunk function,
leading to an assertion failure in the provided test case.

The version of createCast in GlobalMergeFunctions does handle
ArrayTypes, so this common code has been factored out into the
IRBuilder.
---
 llvm/include/llvm/IR/IRBuilder.h              |  7 +++
 llvm/lib/CodeGen/GlobalMergeFunctions.cpp     | 47 ++-----------------
 llvm/lib/IR/IRBuilder.cpp                     | 35 ++++++++++++++
 llvm/lib/Transforms/IPO/MergeFunctions.cpp    | 31 +-----------
 .../Transforms/MergeFunc/crash-cast-arrays.ll | 38 +++++++++++++++
 5 files changed, 87 insertions(+), 71 deletions(-)
 create mode 100644 llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll

diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 750a99cc50dd7..a20fdec3f201d 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2291,6 +2291,13 @@ class IRBuilderBase {
   // isSigned parameter.
   Value *CreateIntCast(Value *, Type *, const char *) = delete;
 
+  /// Cast between aggregate types that must have identical structure but may
+  /// differ in their leaf types. The leaf values are recursively extracted,
+  /// cast, and then reinserted into a value of type DestTy. The leaf types
+  /// must be castable using a bitcast or pointer cast (ptrtoint/inttoptr),
+  /// because signedness is not specified.
+  Value *CreateAggregateCast(Value *V, Type *DestTy);
+
   //===--------------------------------------------------------------------===//
   // Instruction creation methods: Compare Instructions
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
index e920b1be6822c..d4c53e79ed2e1 100644
--- a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
+++ b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
@@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
   return true;
 }
 
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
-    auto *DestAT = dyn_cast<ArrayType>(DestTy);
-    assert(DestAT);
-    assert(SrcAT->getNumElements() == DestAT->getNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestAT->getElementType());
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isArrayTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  return Builder.CreateBitCast(V, DestTy);
-}
-
 void GlobalMergeFunc::analyze(Module &M) {
   ++NumAnalyzedModues;
   for (Function &Func : M) {
@@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
       if (OrigC->getType() != NewArg->getType()) {
         IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
         Inst->setOperand(OpndIndex,
-                         createCast(Builder, NewArg, OrigC->getType()));
+                         Builder.CreateAggregateCast(NewArg, OrigC->getType()));
       } else {
         Inst->setOperand(OpndIndex, NewArg);
       }
@@ -297,7 +259,8 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
 
   // Add arguments which are passed through Thunk.
   for (Argument &AI : Thunk->args()) {
-    Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+    Args.push_back(
+        Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -305,7 +268,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   for (auto *Param : Params) {
     assert(ParamIdx < ToFuncTy->getNumParams());
     Args.push_back(
-        createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
+        Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   if (Thunk->getReturnType()->isVoidTy())
     Builder.CreateRetVoid();
   else
-    Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
+    Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
 }
 
 // Check if the old merged/optimized IndexOperandHashMap is compatible with
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 421b617a5fb7e..58a65ec646557 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -76,6 +76,41 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
     }
 }
 
+Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
+  Type *SrcTy = V->getType();
+  if (SrcTy == DestTy)
+    return V;
+  if (auto *SrcST = dyn_cast<StructType>(SrcTy)) {
+    assert(DestTy->isStructTy() && "Expected StructType");
+    auto *DestST = cast<StructType>(DestTy);
+    assert(SrcST->getNumElements() == DestST->getNumElements());
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcST->getNumElements(); I < E; ++I) {
+      Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+                                           DestST->getElementType(I));
+
+      Result = CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+    assert(DestTy->isArrayTy() && "Expected ArrayType");
+    auto *DestAT = cast<ArrayType>(DestTy);
+    assert(SrcAT->getNumElements() == DestAT->getNumElements());
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+      Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+                                           DestAT->getElementType());
+
+      Result = CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+
+  assert(!DestTy->isAggregateType());
+  return CreateBitOrPointerCast(V, DestTy);
+}
+
 CallInst *
 IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
                                 const Twine &Name, FMFSource FMFSource,
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 924db314674d5..c58c0f40c1b23 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
   }
 }
 
-// Helper for writeThunk,
-// Selects proper bitcast operation,
-// but a bit simpler then CastInst::getCastOpcode.
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  else
-    return Builder.CreateBitCast(V, DestTy);
-}
-
 // Erase the instructions in PDIUnrelatedWL as they are unrelated to the
 // parameter debug info, from the entry block.
 void MergeFunctions::eraseInstsUnrelatedToPDI(
@@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   unsigned i = 0;
   FunctionType *FFTy = F->getFunctionType();
   for (Argument &AI : H->args()) {
-    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
+    Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
     ++i;
   }
 
@@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   if (H->getReturnType()->isVoidTy()) {
     RI = Builder.CreateRetVoid();
   } else {
-    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
+    RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
   }
 
   if (MergeFunctionsPDI) {
diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
new file mode 100644
index 0000000000000..fcbb06400a618
--- /dev/null
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -0,0 +1,38 @@
+; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+
+%A = type { double }
+; the intermediary struct causes A_arr and B_arr to be different types
+%A_struct = type { %A }
+%A_arr = type { [1 x %A_struct] }
+
+%B = type { double }
+%B_struct = type { %B }
+%B_arr = type { [1 x %B_struct] }
+
+declare void @noop()
+
+define %A_arr @a() {
+; CHECK-LABEL: define %A_arr @a() {
+; CHECK-NEXT:    call void @noop()
+; CHECK-NEXT:    ret %A_arr zeroinitializer
+;
+  call void @noop()
+  ret %A_arr zeroinitializer
+}
+
+define %B_arr @b() {
+; CHECK-LABEL: define %B_arr @b() {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call %A_arr @a
+; CHECK-NEXT:    [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
+; CHECK-NEXT:    ret %B_arr [[TMP9]]
+;
+  call void @noop()
+  ret %B_arr zeroinitializer
+}



More information about the llvm-commits mailing list