[llvm] 1302610 - [MergeFunc] Fix crash caused by bitcasting ArrayType (#133259)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 02:16:43 PDT 2025


Author: Tobias Stadler
Date: 2025-04-04T10:16:40+01:00
New Revision: 1302610f03a1f10c2eea4c66445ccba4c52887b6

URL: https://github.com/llvm/llvm-project/commit/1302610f03a1f10c2eea4c66445ccba4c52887b6
DIFF: https://github.com/llvm/llvm-project/commit/1302610f03a1f10c2eea4c66445ccba4c52887b6.diff

LOG: [MergeFunc] Fix crash caused by bitcasting ArrayType (#133259)

createCast in MergeFunctions did not consider ArrayTypes, which results
in the creation of a bitcast between ArrayTypes in the thunk function,
leading to an assertion failure in the provided test case.

The version of createCast in GlobalMergeFunctions does handle
ArrayTypes, so this common code has been factored out into the
IRBuilder.

Added: 
    llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll

Modified: 
    llvm/include/llvm/IR/IRBuilder.h
    llvm/lib/CodeGen/GlobalMergeFunctions.cpp
    llvm/lib/IR/IRBuilder.cpp
    llvm/lib/Transforms/IPO/MergeFunctions.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 07660e93253da..0e68ffadc6939 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2299,6 +2299,13 @@ class IRBuilderBase {
   // isSigned parameter.
   Value *CreateIntCast(Value *, Type *, const char *) = delete;
 
+  /// Cast between aggregate types that must have identical structure but may
+  /// differ in their leaf types. The leaf values are recursively extracted,
+  /// casted, and then reinserted into a value of type DestTy. The leaf types
+  /// must be castable using a bitcast or ptrcast, because signedness is
+  /// not specified.
+  Value *CreateAggregateCast(Value *V, Type *DestTy);
+
   //===--------------------------------------------------------------------===//
   // Instruction creation methods: Compare Instructions
   //===--------------------------------------------------------------------===//

diff --git a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
index e920b1be6822c..d4c53e79ed2e1 100644
--- a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
+++ b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
@@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
   return true;
 }
 
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
-    auto *DestAT = dyn_cast<ArrayType>(DestTy);
-    assert(DestAT);
-    assert(SrcAT->getNumElements() == DestAT->getNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestAT->getElementType());
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isArrayTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  return Builder.CreateBitCast(V, DestTy);
-}
-
 void GlobalMergeFunc::analyze(Module &M) {
   ++NumAnalyzedModues;
   for (Function &Func : M) {
@@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
       if (OrigC->getType() != NewArg->getType()) {
         IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
         Inst->setOperand(OpndIndex,
-                         createCast(Builder, NewArg, OrigC->getType()));
+                         Builder.CreateAggregateCast(NewArg, OrigC->getType()));
       } else {
         Inst->setOperand(OpndIndex, NewArg);
       }
@@ -297,7 +259,8 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
 
   // Add arguments which are passed through Thunk.
   for (Argument &AI : Thunk->args()) {
-    Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+    Args.push_back(
+        Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -305,7 +268,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   for (auto *Param : Params) {
     assert(ParamIdx < ToFuncTy->getNumParams());
     Args.push_back(
-        createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
+        Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   if (Thunk->getReturnType()->isVoidTy())
     Builder.CreateRetVoid();
   else
-    Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
+    Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
 }
 
 // Check if the old merged/optimized IndexOperandHashMap is compatible with

diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 421b617a5fb7e..e5a2f08c393c9 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -76,6 +76,40 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
     }
 }
 
+Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
+  Type *SrcTy = V->getType();
+  if (SrcTy == DestTy)
+    return V;
+
+  if (SrcTy->isAggregateType()) {
+    unsigned NumElements;
+    if (SrcTy->isStructTy()) {
+      assert(DestTy->isStructTy() && "Expected StructType");
+      assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements() &&
+             "Expected StructTypes with equal number of elements");
+      NumElements = SrcTy->getStructNumElements();
+    } else {
+      assert(SrcTy->isArrayTy() && DestTy->isArrayTy() && "Expected ArrayType");
+      assert(SrcTy->getArrayNumElements() == DestTy->getArrayNumElements() &&
+             "Expected ArrayTypes with equal number of elements");
+      NumElements = SrcTy->getArrayNumElements();
+    }
+
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned I = 0; I < NumElements; ++I) {
+      Type *ElementTy = SrcTy->isStructTy() ? DestTy->getStructElementType(I)
+                                            : DestTy->getArrayElementType();
+      Value *Element =
+          CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)), ElementTy);
+
+      Result = CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+
+  return CreateBitOrPointerCast(V, DestTy);
+}
+
 CallInst *
 IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
                                 const Twine &Name, FMFSource FMFSource,

diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 924db314674d5..c58c0f40c1b23 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
   }
 }
 
-// Helper for writeThunk,
-// Selects proper bitcast operation,
-// but a bit simpler then CastInst::getCastOpcode.
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  else
-    return Builder.CreateBitCast(V, DestTy);
-}
-
 // Erase the instructions in PDIUnrelatedWL as they are unrelated to the
 // parameter debug info, from the entry block.
 void MergeFunctions::eraseInstsUnrelatedToPDI(
@@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   unsigned i = 0;
   FunctionType *FFTy = F->getFunctionType();
   for (Argument &AI : H->args()) {
-    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
+    Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
     ++i;
   }
 
@@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   if (H->getReturnType()->isVoidTy()) {
     RI = Builder.CreateRetVoid();
   } else {
-    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
+    RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
   }
 
   if (MergeFunctionsPDI) {

diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
new file mode 100644
index 0000000000000..6a18feba1263a
--- /dev/null
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -0,0 +1,76 @@
+; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
+
+%A = type { double }
+; the intermediary struct causes A_arr and B_arr to be different types
+%A_struct = type { %A }
+%A_arr = type { [1 x %A_struct] }
+
+%B = type { double }
+%B_struct = type { %B }
+%B_arr = type { [1 x %B_struct] }
+
+; conversion between C_arr and D_arr is possible, but requires ptrcast
+%C = type { i64 }
+%C_struct = type { %C }
+%C_arr = type { [1 x %C_struct] }
+
+%D = type { ptr }
+%D_struct = type { %D }
+%D_arr = type { [1 x %D_struct] }
+
+declare void @noop()
+
+define %A_arr @a() {
+; CHECK-LABEL: define %A_arr @a() {
+; CHECK-NEXT:    call void @noop()
+; CHECK-NEXT:    ret %A_arr zeroinitializer
+;
+  call void @noop()
+  ret %A_arr zeroinitializer
+}
+
+define %C_arr @c() {
+; CHECK-LABEL: define %C_arr @c() {
+; CHECK-NEXT:    call void @noop()
+; CHECK-NEXT:    ret %C_arr zeroinitializer
+;
+  call void @noop()
+  ret %C_arr zeroinitializer
+}
+
+define %B_arr @b() {
+; CHECK-LABEL: define %B_arr @b() {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call %A_arr @a
+; CHECK-NEXT:    [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
+; CHECK-NEXT:    ret %B_arr [[TMP9]]
+;
+  call void @noop()
+  ret %B_arr zeroinitializer
+}
+
+define %D_arr @d() {
+; CHECK-LABEL: define %D_arr @d() {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call %C_arr @c
+; CHECK-NEXT:    [[TMP2:%.*]] = extractvalue %C_arr [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractvalue [1 x %C_struct] [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractvalue %C_struct [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractvalue %C [[TMP4]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = inttoptr i64 [[TMP5]] to ptr
+; CHECK-NEXT:    [[TMP6:%.*]] = insertvalue %D poison, ptr [[TMP10]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertvalue %D_struct poison, %D [[TMP6]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertvalue [1 x %D_struct] poison, %D_struct [[TMP7]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertvalue %D_arr poison, [1 x %D_struct] [[TMP8]], 0
+; CHECK-NEXT:    ret %D_arr [[TMP9]]
+;
+  call void @noop()
+  ret %D_arr zeroinitializer
+}


        


More information about the llvm-commits mailing list