[llvm] 53a9175 - [llvm] Handle duplicate call bases when applying branch funneling
Leonard Chan via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 23 14:48:28 PDT 2023
Author: Leonard Chan
Date: 2023-03-23T21:44:59Z
New Revision: 53a917595186d711026505dbc42b95aca5a67825
URL: https://github.com/llvm/llvm-project/commit/53a917595186d711026505dbc42b95aca5a67825
DIFF: https://github.com/llvm/llvm-project/commit/53a917595186d711026505dbc42b95aca5a67825.diff
LOG: [llvm] Handle duplicate call bases when applying branch funneling
It's possible for `DevirtModule::applyICallBranchFunnel` to segfault
when it calls `getCaller` on a call base that was erased in a prior
iteration. This can occur when `findDevirtualizableCallsForTypeTest`
collects devirtualizable calls and the vtable passed to llvm.type.test
is a global rather than a local. The function takes the first argument
of the llvm.type.test call (the vtable), iterates through all of its
uses, and adds every call associated with that intrinsic call to a
vector. In the common case where the vtable is a *local*, this is not
an issue.
Take for example:
```
define i32 @fn(ptr %obj) #0 {
  %vtable = load ptr, ptr %obj
  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr %vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}
```
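The collection step for this example can be modeled with a small, self-contained C++ sketch. It is not LLVM's `findDevirtualizableCallsForTypeTest`: the `Op`/`Inst` types and the `findCallsThroughVTable` helper are hypothetical stand-ins, with every value reduced to an integer id. Starting from the tested vtable, the helper visits each load of it and records every call made through the loaded function pointer.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR model, not LLVM's API: every value is an integer id,
// and an instruction records the single operand this sketch cares about.
enum class Op { Load, Call, TypeTest };

struct Inst {
  int id;            // id of the value this instruction produces
  Op op;
  int operand;       // Load: pointer loaded from; Call: callee; TypeTest: tested vtable
  std::string func;  // name of the enclosing function
};

// Hypothetical helper: starting from the tested vtable, visit each load of
// it and record every call whose callee is the loaded function pointer.
std::vector<int> findCallsThroughVTable(const std::vector<Inst> &module,
                                        int vtable) {
  std::vector<int> calls;
  for (const Inst &load : module) {
    if (load.op != Op::Load || load.operand != vtable)
      continue;  // not a load of the vtable
    for (const Inst &call : module)
      if (call.op == Op::Call && call.operand == load.id)
        calls.push_back(call.id);
  }
  return calls;
}
```

Because a local `%vtable` in one function and a local `%vtable` in another are distinct values (distinct ids here), scanning from either one only reaches the call in its own function.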
`findDevirtualizableCallsForTypeTest` will check the call base
`%result = call i32 %fptr(ptr %obj, i32 1)`, find that it is associated
with a devirtualizable call from `%vtable`, find all loads of `%vtable`,
and add every call made through those load results to a vector. Now
consider the case where `%vtable` is instead the global itself rather
than a local:
```
define i32 @fn(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}
```
`findDevirtualizableCallsForTypeTest` should work normally and add one
unique call instance to the vector. However, if the same global is
passed to llvm.type.test in multiple places, as in:
```
define i32 @fn(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}

define i32 @fn2(ptr %obj) #0 {
  %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
  call void @llvm.assume(i1 %p)
  %fptr = load ptr, ptr @vtable
  %result = call i32 %fptr(ptr %obj, i32 1)
  ret i32 %result
}
```
Then each call base `%result = call i32 %fptr(ptr %obj, i32 1)` will be
added to the vector twice. For either call base, we determine that it is
associated with a devirtualizable call from `@vtable`, and then we
iterate through all the uses of `@vtable`, which span multiple
functions. So scanning the first `%result = call i32 %fptr(ptr %obj,
i32 1)` adds both call bases to the vector, and scanning the second one
adds both call bases again, resulting in duplicate call bases in the
`CSInfo.CallSites` vector.
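The duplication can be reproduced with a toy C++ model. All names here (`Op`, `Inst`, `findCallsThroughVTable`, `collectCallSites`) are hypothetical and are not LLVM code. The model runs the vtable scan once per type test and appends every result to one list, roughly as call sites that share a type identifier and offset end up in a single `CSInfo.CallSites` vector. Because a global vtable is one value shared by every function, both scans reach both calls.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR model (not LLVM's API). Values are integer ids; a
// global such as @vtable is a single id shared by every function.
enum class Op { Load, Call, TypeTest };

struct Inst {
  int id;            // id of the value this instruction produces
  Op op;
  int operand;       // Load: pointer loaded from; Call: callee; TypeTest: tested vtable
  std::string func;  // name of the enclosing function
};

// Visit every load of `vtable` anywhere in the module and record each call
// made through the loaded value. For a global, this crosses functions.
std::vector<int> findCallsThroughVTable(const std::vector<Inst> &module,
                                        int vtable) {
  std::vector<int> calls;
  for (const Inst &load : module) {
    if (load.op != Op::Load || load.operand != vtable)
      continue;
    for (const Inst &call : module)
      if (call.op == Op::Call && call.operand == load.id)
        calls.push_back(call.id);
  }
  return calls;
}

// Run the scan once per type test, appending everything to one list.
std::vector<int> collectCallSites(const std::vector<Inst> &module) {
  std::vector<int> callSites;
  for (const Inst &tt : module) {
    if (tt.op != Op::TypeTest)
      continue;
    std::vector<int> found = findCallsThroughVTable(module, tt.operand);
    callSites.insert(callSites.end(), found.begin(), found.end());
  }
  return callSites;
}
```

With `@vtable` modeled as the shared id 1 and the two functions above encoded as instructions 10 to 12 and 20 to 22, `collectCallSites` returns `{12, 22, 12, 22}`: each call appears twice.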
Note that this is already accounted for everywhere else WPD
(whole-program devirtualization) iterates over CallSites: each call base
is added to the `OptimizedCalls` set, and calls already in the set are
skipped. We can't reuse that set here because it serves a different
purpose: it marks which calls were devirtualized, and
`applyICallBranchFunnel` explicitly does not mark its calls as
devirtualized. For this fix, we account for duplicates with a map and
perform the actual replacements afterwards by iterating over the map.
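A minimal sketch of the deduplicate-then-replace pattern, using a hypothetical `ToyModule` whose `live` set stands in for the call instructions still in the IR. `applyEager` mimics the pre-fix shape (rewrite and erase inside the loop) and `applyDeferred` mimics the fix (skip call sites already in the map, erase only after the loop). Both return `false` where the real pass would touch an erased `CallBase`. These names are illustrative only; the actual patch uses a `std::map<CallBase *, CallBase *>` together with `replaceAllUsesWith` and `eraseFromParent`.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical toy module: `live` holds ids of call instructions that are
// still in the IR; erasing an id models eraseFromParent().
struct ToyModule {
  std::set<int> live;
  int nextId = 1000;  // ids handed out to newly built branch-funnel calls
};

// Pre-fix shape: rewrite and erase each call site inside the loop.
// Returns false on reaching a call that an earlier iteration erased, which
// is where the real pass would call getCaller() on a dangling CallBase.
bool applyEager(ToyModule &m, const std::vector<int> &callSites) {
  for (int cb : callSites) {
    if (m.live.count(cb) == 0)
      return false;
    m.live.insert(m.nextId++);  // build the replacement call
    m.live.erase(cb);           // replaceAllUsesWith + eraseFromParent
  }
  return true;
}

// Fixed shape: remember each rewritten call in a map, skip duplicates,
// and erase the old calls only after every call site has been visited.
bool applyDeferred(ToyModule &m, const std::vector<int> &callSites) {
  std::map<int, int> replacements;  // old call id -> new call id
  for (int cb : callSites) {
    if (replacements.count(cb) != 0)
      continue;  // duplicate entry: this call is already scheduled
    if (m.live.count(cb) == 0)
      return false;
    int newCall = m.nextId++;  // build the replacement call
    m.live.insert(newCall);
    replacements[cb] = newCall;
  }
  for (const auto &entry : replacements)
    m.live.erase(entry.first);  // replaceAllUsesWith + eraseFromParent
  return true;
}
```

Given the duplicated list `{12, 22, 12, 22}`, `applyEager` reaches the second `12` after already erasing it and returns `false`, while `applyDeferred` rewrites each call exactly once and leaves only the two new calls (ids 1000 and 1001) live. Keeping the map's old-to-new pairing is what lets the deferred loop perform the replacement in the real pass.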
Differential Revision: https://reviews.llvm.org/D146267
Added:
Modified:
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index e380b47c735fe..8224de30d6986 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -1391,9 +1391,20 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
IsExported = true;
if (CSInfo.AllCallSitesDevirted)
return;
+
+ std::map<CallBase *, CallBase *> CallBases;
for (auto &&VCallSite : CSInfo.CallSites) {
CallBase &CB = VCallSite.CB;
+ if (CallBases.find(&CB) != CallBases.end()) {
+ // When finding devirtualizable calls, it's possible to find the same
+ // vtable passed to multiple llvm.type.test or llvm.type.checked.load
+ // calls, which can cause duplicate call sites to be recorded in
+ // [Const]CallSites. If we've already found one of these
+ // call instances, just ignore it. It will be replaced later.
+ continue;
+ }
+
// Jump tables are only profitable if the retpoline mitigation is enabled.
Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (!FSAttr.isValid() ||
@@ -1440,8 +1451,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), NewArgAttrs));
- CB.replaceAllUsesWith(NewCS);
- CB.eraseFromParent();
+ CallBases[&CB] = NewCS;
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
@@ -1451,6 +1461,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// retpoline mitigation, which would mean that they are lowered to
// llvm.type.test and therefore require an llvm.type.test resolution for the
// type identifier.
+
+ std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) {
+ CBs.first->replaceAllUsesWith(CBs.second);
+ CBs.first->eraseFromParent();
+ });
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
diff --git a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
index 4a6e3634a5d16..0b1023eee2732 100644
--- a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
+++ b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
@@ -233,6 +233,54 @@ define i32 @fn3_rv(ptr %obj) #0 {
ret i32 %result
}
+; CHECK-LABEL: define i32 @fn4
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4(ptr %obj) #0 {
+ %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+ call void @llvm.assume(i1 %p)
+ %fptr = load ptr, ptr @vt1_1
+ ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1)
+ %result = call i32 %fptr(ptr %obj, i32 1)
+ ; NORETP: call i32 %
+ ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_cpy
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_cpy(ptr %obj) #0 {
+ %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+ call void @llvm.assume(i1 %p)
+ %fptr = load ptr, ptr @vt1_1
+ ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1)
+ %result = call i32 %fptr(ptr %obj, i32 1)
+ ; NORETP: call i32 %
+ ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_rv
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_rv(ptr %obj) #0 {
+ %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
+ call void @llvm.assume(i1 %p)
+ %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+ ; RETP: call i32 @__typeid_typeid1_rv_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1)
+ %result = call i32 %fptr(ptr %obj, i32 1)
+ ; NORETP: call i32 %
+ ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_rv_cpy
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_rv_cpy(ptr %obj) #0 {
+ %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
+ call void @llvm.assume(i1 %p)
+ %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+ ; RETP: call i32 @__typeid_typeid1_rv_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1)
+ %result = call i32 %fptr(ptr %obj, i32 1)
+ ; NORETP: call i32 %
+ ret i32 %result
+}
+
; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...)
; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2, ptr {{(nonnull )?}}@vf1_2, ...)