[llvm] [MergeFunc] Fix crash caused by bitcasting ArrayType (PR #133259)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 27 07:43:25 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Tobias Stadler (tobias-stadler)
<details>
<summary>Changes</summary>
createCast in MergeFunctions did not handle ArrayTypes, so it created a bitcast between ArrayTypes in the thunk function, which led to an assertion failure in the provided test case.
The version of createCast in GlobalMergeFunctions does handle ArrayTypes, so this common code has been factored out into IRBuilder as CreateAggregateCast.
---
Full diff: https://github.com/llvm/llvm-project/pull/133259.diff
5 Files Affected:
- (modified) llvm/include/llvm/IR/IRBuilder.h (+7)
- (modified) llvm/lib/CodeGen/GlobalMergeFunctions.cpp (+5-42)
- (modified) llvm/lib/IR/IRBuilder.cpp (+35)
- (modified) llvm/lib/Transforms/IPO/MergeFunctions.cpp (+2-29)
- (added) llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll (+38)
``````````diff
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 750a99cc50dd7..a20fdec3f201d 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2291,6 +2291,13 @@ class IRBuilderBase {
// isSigned parameter.
Value *CreateIntCast(Value *, Type *, const char *) = delete;
+ /// Cast between aggregate types that must have identical structure but may
+ /// differ in their leaf types. The leaf values are recursively extracted,
+ /// cast, and then reinserted into a value of type DestTy. The leaf types
+ /// must be castable with a bitcast, ptrtoint or inttoptr, because signedness
+ /// is not specified. V is returned unchanged if it already has type DestTy.
+ Value *CreateAggregateCast(Value *V, Type *DestTy);
+
//===--------------------------------------------------------------------===//
// Instruction creation methods: Compare Instructions
//===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
index e920b1be6822c..d4c53e79ed2e1 100644
--- a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
+++ b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
@@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
return true;
}
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
- Type *SrcTy = V->getType();
- if (SrcTy->isStructTy()) {
- assert(DestTy->isStructTy());
- assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestTy->getStructElementType(I));
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isStructTy());
- if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
- auto *DestAT = dyn_cast<ArrayType>(DestTy);
- assert(DestAT);
- assert(SrcAT->getNumElements() == DestAT->getNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestAT->getElementType());
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isArrayTy());
- if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
- return Builder.CreateIntToPtr(V, DestTy);
- if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
- return Builder.CreatePtrToInt(V, DestTy);
- return Builder.CreateBitCast(V, DestTy);
-}
-
void GlobalMergeFunc::analyze(Module &M) {
++NumAnalyzedModues;
for (Function &Func : M) {
@@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
if (OrigC->getType() != NewArg->getType()) {
IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
Inst->setOperand(OpndIndex,
- createCast(Builder, NewArg, OrigC->getType()));
+ Builder.CreateAggregateCast(NewArg, OrigC->getType()));
} else {
Inst->setOperand(OpndIndex, NewArg);
}
@@ -297,7 +259,8 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
// Add arguments which are passed through Thunk.
for (Argument &AI : Thunk->args()) {
- Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+ Args.push_back(
+ Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}
@@ -305,7 +268,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
for (auto *Param : Params) {
assert(ParamIdx < ToFuncTy->getNumParams());
Args.push_back(
- createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
+ Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}
@@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
if (Thunk->getReturnType()->isVoidTy())
Builder.CreateRetVoid();
else
- Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
+ Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
}
// Check if the old merged/optimized IndexOperandHashMap is compatible with
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 421b617a5fb7e..58a65ec646557 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -76,6 +76,41 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
}
}
+Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
+ Type *SrcTy = V->getType();
+ if (SrcTy == DestTy) // Identical types: no instructions are emitted.
+ return V;
+ if (auto *SrcST = dyn_cast<StructType>(SrcTy)) { // Structs: cast each field.
+ assert(DestTy->isStructTy() && "Expected StructType");
+ auto *DestST = cast<StructType>(DestTy);
+ assert(SrcST->getNumElements() == DestST->getNumElements());
+ Value *Result = PoisonValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcST->getNumElements(); I < E; ++I) {
+ Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+ DestST->getElementType(I));
+ // Store the converted field into the same slot of the poison-seeded result.
+ Result = CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+ if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) { // Arrays: cast each element.
+ assert(DestTy->isArrayTy() && "Expected ArrayType");
+ auto *DestAT = cast<ArrayType>(DestTy);
+ assert(SrcAT->getNumElements() == DestAT->getNumElements());
+ Value *Result = PoisonValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+ Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+ DestAT->getElementType());
+ // Store the converted element into the same slot of the poison-seeded result.
+ Result = CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+ // Leaf value: CreateBitOrPointerCast picks bitcast, ptrtoint or inttoptr.
+ assert(!DestTy->isAggregateType());
+ return CreateBitOrPointerCast(V, DestTy);
+}
+
CallInst *
IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
const Twine &Name, FMFSource FMFSource,
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 924db314674d5..c58c0f40c1b23 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
}
}
-// Helper for writeThunk,
-// Selects proper bitcast operation,
-// but a bit simpler then CastInst::getCastOpcode.
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
- Type *SrcTy = V->getType();
- if (SrcTy->isStructTy()) {
- assert(DestTy->isStructTy());
- assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestTy->getStructElementType(I));
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isStructTy());
- if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
- return Builder.CreateIntToPtr(V, DestTy);
- else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
- return Builder.CreatePtrToInt(V, DestTy);
- else
- return Builder.CreateBitCast(V, DestTy);
-}
-
// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
// parameter debug info, from the entry block.
void MergeFunctions::eraseInstsUnrelatedToPDI(
@@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
unsigned i = 0;
FunctionType *FFTy = F->getFunctionType();
for (Argument &AI : H->args()) {
- Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
+ Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
++i;
}
@@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
if (H->getReturnType()->isVoidTy()) {
RI = Builder.CreateRetVoid();
} else {
- RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
+ RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
}
if (MergeFunctionsPDI) {
diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
new file mode 100644
index 0000000000000..fcbb06400a618
--- /dev/null
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -0,0 +1,38 @@
+; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+
+%A = type { double }
+; the intermediary struct causes A_arr and B_arr to be different types
+%A_struct = type { %A }
+%A_arr = type { [1 x %A_struct] }
+
+%B = type { double }
+%B_struct = type { %B }
+%B_arr = type { [1 x %B_struct] }
+
+declare void @noop()
+
+define %A_arr @a() {
+; CHECK-LABEL: define %A_arr @a() {
+; CHECK-NEXT: call void @noop()
+; CHECK-NEXT: ret %A_arr zeroinitializer
+;
+ call void @noop()
+ ret %A_arr zeroinitializer
+}
+
+define %B_arr @b() {
+; CHECK-LABEL: define %B_arr @b() {
+; CHECK-NEXT: [[TMP1:%.*]] = tail call %A_arr @a
+; CHECK-NEXT: [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
+; CHECK-NEXT: [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
+; CHECK-NEXT: [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
+; CHECK-NEXT: [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
+; CHECK-NEXT: [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
+; CHECK-NEXT: ret %B_arr [[TMP9]]
+;
+ call void @noop()
+ ret %B_arr zeroinitializer
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/133259
More information about the llvm-commits
mailing list