[llvm] [MergeFunc] Fix crash caused by bitcasting ArrayType (PR #133259)
Tobias Stadler via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 3 08:30:44 PDT 2025
https://github.com/tobias-stadler updated https://github.com/llvm/llvm-project/pull/133259
>From 51d43552b2678739990f6fa4c7dbce09b7426254 Mon Sep 17 00:00:00 2001
From: Tobias Stadler <mail at stadler-tobias.de>
Date: Wed, 26 Mar 2025 16:32:44 +0000
Subject: [PATCH 1/3] [MergeFunc] Fix crash caused by bitcasting ArrayType
createCast in MergeFunctions did not handle ArrayTypes, so it emitted a
bitcast between two array types in the thunk function, which triggers an
assertion failure on the provided test case.
GlobalMergeFunctions' version of createCast already handles ArrayTypes,
so this shared code has been factored out into IRBuilder as
CreateAggregateCast.
---
llvm/include/llvm/IR/IRBuilder.h | 7 +++
llvm/lib/CodeGen/GlobalMergeFunctions.cpp | 47 ++-----------------
llvm/lib/IR/IRBuilder.cpp | 35 ++++++++++++++
llvm/lib/Transforms/IPO/MergeFunctions.cpp | 31 +-----------
.../Transforms/MergeFunc/crash-cast-arrays.ll | 38 +++++++++++++++
5 files changed, 87 insertions(+), 71 deletions(-)
create mode 100644 llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 750a99cc50dd7..a20fdec3f201d 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2291,6 +2291,13 @@ class IRBuilderBase {
// isSigned parameter.
Value *CreateIntCast(Value *, Type *, const char *) = delete;
+ /// Cast between aggregate types that must have identical structure but may
+ /// differ in their leaf types. The leaf values are recursively extracted,
+ /// casted, and then reinserted into a value of type DestTy. The leaf types
+ /// must be castable using a bitcast or ptrcast, because signedness is
+ /// not specified.
+ Value *CreateAggregateCast(Value *V, Type *DestTy);
+
//===--------------------------------------------------------------------===//
// Instruction creation methods: Compare Instructions
//===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
index e920b1be6822c..d4c53e79ed2e1 100644
--- a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
+++ b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
@@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
return true;
}
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
- Type *SrcTy = V->getType();
- if (SrcTy->isStructTy()) {
- assert(DestTy->isStructTy());
- assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestTy->getStructElementType(I));
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isStructTy());
- if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
- auto *DestAT = dyn_cast<ArrayType>(DestTy);
- assert(DestAT);
- assert(SrcAT->getNumElements() == DestAT->getNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestAT->getElementType());
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isArrayTy());
- if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
- return Builder.CreateIntToPtr(V, DestTy);
- if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
- return Builder.CreatePtrToInt(V, DestTy);
- return Builder.CreateBitCast(V, DestTy);
-}
-
void GlobalMergeFunc::analyze(Module &M) {
++NumAnalyzedModues;
for (Function &Func : M) {
@@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
if (OrigC->getType() != NewArg->getType()) {
IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
Inst->setOperand(OpndIndex,
- createCast(Builder, NewArg, OrigC->getType()));
+ Builder.CreateAggregateCast(NewArg, OrigC->getType()));
} else {
Inst->setOperand(OpndIndex, NewArg);
}
@@ -297,7 +259,8 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
// Add arguments which are passed through Thunk.
for (Argument &AI : Thunk->args()) {
- Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+ Args.push_back(
+ Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}
@@ -305,7 +268,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
for (auto *Param : Params) {
assert(ParamIdx < ToFuncTy->getNumParams());
Args.push_back(
- createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
+ Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}
@@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
if (Thunk->getReturnType()->isVoidTy())
Builder.CreateRetVoid();
else
- Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
+ Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
}
// Check if the old merged/optimized IndexOperandHashMap is compatible with
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 421b617a5fb7e..58a65ec646557 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -76,6 +76,41 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
}
}
+Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
+ Type *SrcTy = V->getType();
+ if (SrcTy == DestTy)
+ return V;
+ if (auto *SrcST = dyn_cast<StructType>(SrcTy)) {
+ assert(DestTy->isStructTy() && "Expected StructType");
+ auto *DestST = cast<StructType>(DestTy);
+ assert(SrcST->getNumElements() == DestST->getNumElements());
+ Value *Result = PoisonValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcST->getNumElements(); I < E; ++I) {
+ Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+ DestST->getElementType(I));
+
+ Result = CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+ if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+ assert(DestTy->isArrayTy() && "Expected ArrayType");
+ auto *DestAT = cast<ArrayType>(DestTy);
+ assert(SrcAT->getNumElements() == DestAT->getNumElements());
+ Value *Result = PoisonValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+ Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+ DestAT->getElementType());
+
+ Result = CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+
+ assert(!DestTy->isAggregateType());
+ return CreateBitOrPointerCast(V, DestTy);
+}
+
CallInst *
IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
const Twine &Name, FMFSource FMFSource,
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 924db314674d5..c58c0f40c1b23 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
}
}
-// Helper for writeThunk,
-// Selects proper bitcast operation,
-// but a bit simpler then CastInst::getCastOpcode.
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
- Type *SrcTy = V->getType();
- if (SrcTy->isStructTy()) {
- assert(DestTy->isStructTy());
- assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
- Value *Element =
- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
- DestTy->getStructElementType(I));
-
- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
- }
- return Result;
- }
- assert(!DestTy->isStructTy());
- if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
- return Builder.CreateIntToPtr(V, DestTy);
- else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
- return Builder.CreatePtrToInt(V, DestTy);
- else
- return Builder.CreateBitCast(V, DestTy);
-}
-
// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
// parameter debug info, from the entry block.
void MergeFunctions::eraseInstsUnrelatedToPDI(
@@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
unsigned i = 0;
FunctionType *FFTy = F->getFunctionType();
for (Argument &AI : H->args()) {
- Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
+ Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
++i;
}
@@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
if (H->getReturnType()->isVoidTy()) {
RI = Builder.CreateRetVoid();
} else {
- RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
+ RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
}
if (MergeFunctionsPDI) {
diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
new file mode 100644
index 0000000000000..fcbb06400a618
--- /dev/null
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -0,0 +1,38 @@
+; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+
+%A = type { double }
+; the intermediary struct causes A_arr and B_arr to be different types
+%A_struct = type { %A }
+%A_arr = type { [1 x %A_struct] }
+
+%B = type { double }
+%B_struct = type { %B }
+%B_arr = type { [1 x %B_struct] }
+
+declare void @noop()
+
+define %A_arr @a() {
+; CHECK-LABEL: define %A_arr @a() {
+; CHECK-NEXT: call void @noop()
+; CHECK-NEXT: ret %A_arr zeroinitializer
+;
+ call void @noop()
+ ret %A_arr zeroinitializer
+}
+
+define %B_arr @b() {
+; CHECK-LABEL: define %B_arr @b() {
+; CHECK-NEXT: [[TMP1:%.*]] = tail call %A_arr @a
+; CHECK-NEXT: [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
+; CHECK-NEXT: [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
+; CHECK-NEXT: [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
+; CHECK-NEXT: [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
+; CHECK-NEXT: [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
+; CHECK-NEXT: ret %B_arr [[TMP9]]
+;
+ call void @noop()
+ ret %B_arr zeroinitializer
+}
>From efd9a26d988e7bf01523fe20c4d5b40724dd3927 Mon Sep 17 00:00:00 2001
From: Tobias Stadler <mail at stadler-tobias.de>
Date: Wed, 2 Apr 2025 16:32:27 +0100
Subject: [PATCH 2/3] Unify Struct and Array handling; Add ptrcast test
---
llvm/lib/IR/IRBuilder.cpp | 38 +++++++++----------
.../Transforms/MergeFunc/crash-cast-arrays.ll | 38 +++++++++++++++++++
2 files changed, 57 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 58a65ec646557..c373af06a9a82 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -80,34 +80,34 @@ Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
Type *SrcTy = V->getType();
if (SrcTy == DestTy)
return V;
- if (auto *SrcST = dyn_cast<StructType>(SrcTy)) {
- assert(DestTy->isStructTy() && "Expected StructType");
- auto *DestST = cast<StructType>(DestTy);
- assert(SrcST->getNumElements() == DestST->getNumElements());
- Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcST->getNumElements(); I < E; ++I) {
- Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
- DestST->getElementType(I));
- Result = CreateInsertValue(Result, Element, ArrayRef(I));
+ if (SrcTy->isAggregateType()) {
+ unsigned NumElements;
+ if (SrcTy->isStructTy()) {
+ assert(DestTy->isStructTy() && "Expected StructType");
+ assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements() &&
+ "Expected StructTypes with equal number of elements");
+ NumElements = SrcTy->getStructNumElements();
+ } else {
+ assert(SrcTy->isArrayTy());
+ assert(DestTy->isArrayTy() && "Expected ArrayType");
+ assert(SrcTy->getArrayNumElements() == DestTy->getArrayNumElements() &&
+ "Expected ArrayTypes with equal number of elements");
+ NumElements = SrcTy->getArrayNumElements();
}
- return Result;
- }
- if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
- assert(DestTy->isArrayTy() && "Expected ArrayType");
- auto *DestAT = cast<ArrayType>(DestTy);
- assert(SrcAT->getNumElements() == DestAT->getNumElements());
+
Value *Result = PoisonValue::get(DestTy);
- for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
- Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
- DestAT->getElementType());
+ for (unsigned I = 0; I < NumElements; ++I) {
+ Type *ElementTy = SrcTy->isStructTy() ? DestTy->getStructElementType(I)
+ : DestTy->getArrayElementType();
+ Value *Element =
+ CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)), ElementTy);
Result = CreateInsertValue(Result, Element, ArrayRef(I));
}
return Result;
}
- assert(!DestTy->isAggregateType());
return CreateBitOrPointerCast(V, DestTy);
}
diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
index fcbb06400a618..6a18feba1263a 100644
--- a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -1,5 +1,7 @@
; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
+
%A = type { double }
; the intermediary struct causes A_arr and B_arr to be different types
%A_struct = type { %A }
@@ -9,6 +11,15 @@
%B_struct = type { %B }
%B_arr = type { [1 x %B_struct] }
+; conversion between C_arr and D_arr is possible, but requires ptrcast
+%C = type { i64 }
+%C_struct = type { %C }
+%C_arr = type { [1 x %C_struct] }
+
+%D = type { ptr }
+%D_struct = type { %D }
+%D_arr = type { [1 x %D_struct] }
+
declare void @noop()
define %A_arr @a() {
@@ -20,6 +31,15 @@ define %A_arr @a() {
ret %A_arr zeroinitializer
}
+define %C_arr @c() {
+; CHECK-LABEL: define %C_arr @c() {
+; CHECK-NEXT: call void @noop()
+; CHECK-NEXT: ret %C_arr zeroinitializer
+;
+ call void @noop()
+ ret %C_arr zeroinitializer
+}
+
define %B_arr @b() {
; CHECK-LABEL: define %B_arr @b() {
; CHECK-NEXT: [[TMP1:%.*]] = tail call %A_arr @a
@@ -36,3 +56,21 @@ define %B_arr @b() {
call void @noop()
ret %B_arr zeroinitializer
}
+
+define %D_arr @d() {
+; CHECK-LABEL: define %D_arr @d() {
+; CHECK-NEXT: [[TMP1:%.*]] = tail call %C_arr @c
+; CHECK-NEXT: [[TMP2:%.*]] = extractvalue %C_arr [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [1 x %C_struct] [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue %C_struct [[TMP3]], 0
+; CHECK-NEXT: [[TMP5:%.*]] = extractvalue %C [[TMP4]], 0
+; CHECK-NEXT: [[TMP10:%.*]] = inttoptr i64 [[TMP5]] to ptr
+; CHECK-NEXT: [[TMP6:%.*]] = insertvalue %D poison, ptr [[TMP10]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertvalue %D_struct poison, %D [[TMP6]], 0
+; CHECK-NEXT: [[TMP8:%.*]] = insertvalue [1 x %D_struct] poison, %D_struct [[TMP7]], 0
+; CHECK-NEXT: [[TMP9:%.*]] = insertvalue %D_arr poison, [1 x %D_struct] [[TMP8]], 0
+; CHECK-NEXT: ret %D_arr [[TMP9]]
+;
+ call void @noop()
+ ret %D_arr zeroinitializer
+}
>From a5775538015dcb589d1a44a257871adbda18cb98 Mon Sep 17 00:00:00 2001
From: Tobias Stadler <mail at stadler-tobias.de>
Date: Thu, 3 Apr 2025 16:30:36 +0100
Subject: [PATCH 3/3] Update llvm/lib/IR/IRBuilder.cpp
Co-authored-by: Florian Hahn <flo at fhahn.com>
---
llvm/lib/IR/IRBuilder.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index c373af06a9a82..e5a2f08c393c9 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -89,8 +89,7 @@ Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
"Expected StructTypes with equal number of elements");
NumElements = SrcTy->getStructNumElements();
} else {
- assert(SrcTy->isArrayTy());
- assert(DestTy->isArrayTy() && "Expected ArrayType");
+ assert(SrcTy->isArrayTy() && DestTy->isArrayTy() && "Expected ArrayType");
assert(SrcTy->getArrayNumElements() == DestTy->getArrayNumElements() &&
"Expected ArrayTypes with equal number of elements");
NumElements = SrcTy->getArrayNumElements();
More information about the llvm-commits
mailing list