[PATCH] D146267: [llvm] Handle duplicate call bases when applying branch funneling

Leonard Chan via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 16 16:50:04 PDT 2023


leonardchan created this revision.
leonardchan added reviewers: tejohnson, pcc.
leonardchan added a project: LLVM.
Herald added subscribers: ormris, hiraditya, Prazek.
Herald added a project: All.
leonardchan requested review of this revision.

It's possible to segfault in `DevirtModule::applyICallBranchFunnel` when calling `getCaller` on a call base that was erased in a prior iteration. This can happen when `findDevirtualizableCallsForTypeTest` looks for devirtualizable calls and the vtable passed to llvm.type.test is a global rather than a local. The function takes the first argument of the llvm.type.test call (the vtable), iterates through all of its uses, and adds every relevant call made through a pointer loaded from that vtable to a vector. In the common case where the vtable is actually a *local*, this isn't an issue. Take for example:

  define i32 @fn(ptr %obj) #0 {
    %vtable = load ptr, ptr %obj
    %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
    call void @llvm.assume(i1 %p)
    %fptr = load ptr, ptr %vtable
    %result = call i32 %fptr(ptr %obj, i32 1)
    ret i32 %result
  }
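
The use walk described at the top can be pictured with a toy model. This is a minimal C++ sketch, assuming a drastically simplified IR in which the only relevant instructions are loads, type tests, and indirect calls. `Node`, `Module`, `collectCalls`, and `buildLocalExample` are hypothetical names, not LLVM APIs, and `collectCalls` is not the real `findDevirtualizableCallsForTypeTest`, which handles more cases (for example the `llvm.load.relative` form used in the tests). For `@fn` above, the walk starts at `%vtable`, whose only users are the type test and the `%fptr` load, so it finds exactly one call.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not LLVM's classes): just enough structure to show how
// walking the uses of a vtable value collects indirect call sites.
struct Node {
  enum Kind { Arg, Global, Load, TypeTest, IndirectCall } K;
  std::string Name;
  Node *Operand;             // Load: address; TypeTest: vtable; call: callee
  std::vector<Node *> Users; // every node whose Operand is this node
};

struct Module {
  std::deque<Node> Nodes; // push_back on a deque keeps element addresses valid
  Node *add(Node::Kind K, std::string Name, Node *Op = nullptr) {
    Nodes.push_back(Node{K, std::move(Name), Op, {}});
    Node *N = &Nodes.back();
    if (Op)
      Op->Users.push_back(N); // record the use, like an llvm::Value use list
    return N;
  }
};

// Simplified use walk: take the vtable operand of a type test, find each load
// from it, and collect each indirect call made through that loaded pointer.
std::vector<Node *> collectCalls(const Node *TT) {
  assert(TT->K == Node::TypeTest && "expected a type test");
  std::vector<Node *> Calls;
  for (Node *U : TT->Operand->Users)
    if (U->K == Node::Load)
      for (Node *C : U->Users)
        if (C->K == Node::IndirectCall && C->Operand == U)
          Calls.push_back(C);
  return Calls;
}

// Builds @fn from the local-vtable example and returns its type test;
// *Result receives the %result call.
Node *buildLocalExample(Module &M, Node **Result) {
  Node *Obj = M.add(Node::Arg, "obj");
  Node *VTable = M.add(Node::Load, "vtable", Obj); // %vtable = load ptr, ptr %obj
  Node *TT = M.add(Node::TypeTest, "p", VTable);
  Node *FPtr = M.add(Node::Load, "fptr", VTable);
  *Result = M.add(Node::IndirectCall, "result", FPtr);
  return TT;
}
```

Because `%vtable` is local to `@fn`, its use list cannot reach into other functions, so in this case the walk cannot pick up calls from anywhere else.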

`findDevirtualizableCallsForTypeTest` will determine that the call base `%result = call i32 %fptr(ptr %obj, i32 1)` is a devirtualizable call through `%vtable`: it finds all loads from `%vtable` and adds every call made through those loaded pointers to a vector. Now consider the case where `%vtable` is instead the global itself rather than a local:

  define i32 @fn(ptr %obj) #0 {
    %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
    call void @llvm.assume(i1 %p)
    %fptr = load ptr, ptr @vtable
    %result = call i32 %fptr(ptr %obj, i32 1)
    ret i32 %result
  }

Here `findDevirtualizableCallsForTypeTest` still works as expected and adds a single unique call instance to the vector. However, if the same global is passed to llvm.type.test in more than one place, as in:

  define i32 @fn(ptr %obj) #0 {
    %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
    call void @llvm.assume(i1 %p)
    %fptr = load ptr, ptr @vtable
    %result = call i32 %fptr(ptr %obj, i32 1)
    ret i32 %result
  }

  define i32 @fn2(ptr %obj) #0 {
    %p = call i1 @llvm.type.test(ptr @vtable, metadata !"typeid2")
    call void @llvm.assume(i1 %p)
    %fptr = load ptr, ptr @vtable
    %result = call i32 %fptr(ptr %obj, i32 1)
    ret i32 %result
  }

Then each call base `%result = call i32 %fptr(ptr %obj, i32 1)` will be added to the vector twice. For either call base, we determine that it is associated with a devirtualizable call through `@vtable` and then iterate through all uses of `@vtable`, which span multiple functions. So scanning the first `%result` call adds both call bases to the vector, and scanning the second adds both again, resulting in duplicate call bases in the `CSInfo.CallSites` vector.
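A toy model of the two-function example shows where the duplicates come from. This is a sketch under simplifying assumptions: `Node`, `Module`, `collectCalls`, `addFunctionBody`, and `gatherCallSites` are hypothetical names, not LLVM APIs, and `gatherCallSites` simply concatenates the result of one use walk per type test as a stand-in for how entries accumulate in `CSInfo.CallSites`. Since `@vtable` is a global, its use list contains the loads from both `@fn` and `@fn2`, so each of the two walks returns both calls.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not LLVM's classes).
struct Node {
  enum Kind { Arg, Global, Load, TypeTest, IndirectCall } K;
  std::string Name;
  Node *Operand;             // Load: address; TypeTest: vtable; call: callee
  std::vector<Node *> Users; // every node whose Operand is this node
};

struct Module {
  std::deque<Node> Nodes; // push_back on a deque keeps element addresses valid
  Node *add(Node::Kind K, std::string Name, Node *Op = nullptr) {
    Nodes.push_back(Node{K, std::move(Name), Op, {}});
    Node *N = &Nodes.back();
    if (Op)
      Op->Users.push_back(N); // record the use on the operand
    return N;
  }
};

// Simplified use walk over every user of the type-tested vtable.
std::vector<Node *> collectCalls(const Node *TT) {
  assert(TT->K == Node::TypeTest && "expected a type test");
  std::vector<Node *> Calls;
  for (Node *U : TT->Operand->Users)
    if (U->K == Node::Load)
      for (Node *C : U->Users)
        if (C->K == Node::IndirectCall && C->Operand == U)
          Calls.push_back(C);
  return Calls;
}

// Adds one function body that type-tests and loads from the global VTable.
// Returns the type test; *Call receives the %result call.
Node *addFunctionBody(Module &M, Node *VTable, Node **Call) {
  Node *TT = M.add(Node::TypeTest, "p", VTable);
  Node *FPtr = M.add(Node::Load, "fptr", VTable);
  *Call = M.add(Node::IndirectCall, "result", FPtr);
  return TT;
}

// Runs the use walk once per type test and appends everything it finds.
std::vector<Node *> gatherCallSites(const std::vector<Node *> &TypeTests) {
  std::vector<Node *> CallSites;
  for (const Node *TT : TypeTests) {
    std::vector<Node *> Found = collectCalls(TT);
    CallSites.insert(CallSites.end(), Found.begin(), Found.end());
  }
  return CallSites;
}
```

In this model, two type tests on `@vtable` yield four entries but only two distinct calls; with *n* functions testing the same global, each of the *n* calls would appear *n* times.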

Note that this is already accounted for everywhere else WPD iterates over CallSites: those places add the call base to the `OptimizedCalls` set and skip it if it is already there. We can't reuse that set here because it serves a different purpose, marking which calls were devirtualized, and `applyICallBranchFunnel` explicitly does not mark its calls as devirtualized. The simplest fix seems to be to check whether the call base was already erased in a previous iteration. An alternative approach would be to change `CSInfo.CallSites` from a vector to a set.
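The set-based alternative mentioned above can be sketched as a standalone C++ model. `Call`, `Body`, and `applyFunnel` are hypothetical stand-ins, not LLVM types, and the sketch assumes the rewrite can be split into a visit phase and an erase phase. The seen-set check runs before the loop touches a call, so a duplicate entry is skipped without ever being dereferenced; that matters in LLVM because `Instruction::eraseFromParent` deletes the instruction as well as unlinking it.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a call instruction; not llvm::CallBase.
struct Call {
  std::string Name;
};

// Hypothetical stand-in for a function body that owns its calls, so erasing
// a call really frees it.
struct Body {
  std::vector<std::unique_ptr<Call>> Calls;
  void erase(Call *C) {
    Calls.erase(std::remove_if(Calls.begin(), Calls.end(),
                               [C](const std::unique_ptr<Call> &P) {
                                 return P.get() == C;
                               }),
                Calls.end());
  }
};

// Visits each distinct call once, even if CallSites lists it several times.
// Returns the number of calls rewritten.
int applyFunnel(Body &B, const std::vector<Call *> &CallSites) {
  std::unordered_set<Call *> Seen;
  std::vector<Call *> ToErase;
  for (Call *C : CallSites) {
    assert(C && "null call site");
    // insert().second is false when C was already visited, so the duplicate
    // is skipped before anything dereferences it.
    if (!Seen.insert(C).second)
      continue;
    // A real pass would build the replacement call from *C here.
    ToErase.push_back(C);
  }
  // Erase only after the loop, so no CallSites entry dangles mid-walk.
  for (Call *C : ToErase)
    B.erase(C);
  return static_cast<int>(ToErase.size());
}
```

Deferring the erasure keeps every pointer in `CallSites` valid for the whole loop, so the duplicate check never has to inspect a call that may already be gone.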


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D146267

Files:
  llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
  llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll


Index: llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
===================================================================
--- llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
+++ llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
@@ -233,6 +233,54 @@
   ret i32 %result
 }
 
+; CHECK-LABEL: define i32 @fn4
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4(ptr %obj) #0 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr @vt1_1
+  ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ; NORETP: call i32 %
+  ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_cpy
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_cpy(ptr %obj) #0 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr @vt1_1
+  ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1, ptr %obj, i32 1)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ; NORETP: call i32 %
+  ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_rv
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_rv(ptr %obj) #0 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+  ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ; NORETP: call i32 %
+  ret i32 %result
+}
+
+; CHECK-LABEL: define i32 @fn4_rv_cpy
+; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
+define i32 @fn4_rv_cpy(ptr %obj) #0 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+  ; RETP: call i32 @__typeid_typeid1_0_branch_funnel(ptr nest @vt1_1_rv, ptr %obj, i32 1)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ; NORETP: call i32 %
+  ret i32 %result
+}
+
 ; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...)
 ; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2, ptr {{(nonnull )?}}@vf1_2, ...)
 
Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -1394,6 +1394,22 @@
     for (auto &&VCallSite : CSInfo.CallSites) {
       CallBase &CB = VCallSite.CB;
 
+      if (!CB.getParent()) {
+        // When finding devirtualizable calls, it's possible to find the same
+        // vtable passed to multiple llvm.type.test or llvm.type.checked.load
+        // calls, which can cause duplicate call sites to be recorded in
+        // [Const]CallSites. If we've already replaced and erased one of these
+        // call instances, just ignore it.
+        //
+        // All other areas in WPD which iterate over the CallSites account for
+        // duplicate call bases by adding them to a set that marks them as
+        // devirtualized and skipping them if they're already in the set. This
+        // function doesn't mark calls as devirtualized, so we can't use the
+        // same set. It's also much easier to detect duplicates by checking
+        // whether they've been erased than to track them in a set.
+        continue;
+      }
+
       // Jump tables are only profitable if the retpoline mitigation is enabled.
       Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
       if (!FSAttr.isValid() ||

