[llvm] b441707 - [FuncSpec] Constant propagate multiple arguments for recursive functions.

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 31 05:08:36 PDT 2022


Author: Alexandros Lamprineas
Date: 2022-03-31T13:00:08+01:00
New Revision: b4417075dc1cbfac0a3f777850ba77c031d7db3c

URL: https://github.com/llvm/llvm-project/commit/b4417075dc1cbfac0a3f777850ba77c031d7db3c
DIFF: https://github.com/llvm/llvm-project/commit/b4417075dc1cbfac0a3f777850ba77c031d7db3c.diff

LOG: [FuncSpec] Constant propagate multiple arguments for recursive functions.

This fixes a TODO in constantArgPropagation() to make it feature complete.
However, I do find myself in agreement with the review comments in
https://reviews.llvm.org/D106426. I don't think we should pursue
specializing such recursive functions, as the code size increase becomes
linear in 'max-iters'. Compiling the modified test with just -O3 (no
function specialization) generates the same code.

Differential Revision: https://reviews.llvm.org/D122755

Added: 
    

Modified: 
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index c9775e097a45d..fe8b788c43306 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -19,7 +19,6 @@
 // Current limitations:
 // - It does not yet handle integer ranges. We do support "literal constants",
 //   but that's off by default under an option.
-// - Only 1 argument per function is specialised,
 // - The cost-model could be further looked into (it mainly focuses on inlining
 //   benefits),
 // - We are not yet caching analysis results, but profiling and checking where
@@ -210,35 +209,39 @@ static void constantArgPropagation(FuncList &WorkList, Module &M,
   // are any new constant values for the call instruction via
   // stack variables.
   for (auto *F : WorkList) {
-    // TODO: Generalize for any read only arguments.
-    if (F->arg_size() != 1)
-      continue;
-
-    auto &Arg = *F->arg_begin();
-    if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
-      continue;
 
     for (auto *User : F->users()) {
+
       auto *Call = dyn_cast<CallInst>(User);
       if (!Call)
-        break;
-      auto *ArgOp = Call->getArgOperand(0);
-      auto *ArgOpType = ArgOp->getType();
-      auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
-      if (!ConstVal)
-        break;
+        continue;
 
-      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
-                                     GlobalValue::InternalLinkage, ConstVal,
-                                     "funcspec.arg");
+      bool Changed = false;
+      for (const Use &U : Call->args()) {
+        unsigned Idx = Call->getArgOperandNo(&U);
+        Value *ArgOp = Call->getArgOperand(Idx);
+        Type *ArgOpType = ArgOp->getType();
 
-      if (ArgOpType != ConstVal->getType())
-        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
+        if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
+          continue;
+
+        auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
+        if (!ConstVal)
+          continue;
 
-      Call->setArgOperand(0, GV);
+        Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                       GlobalValue::InternalLinkage, ConstVal,
+                                       "funcspec.arg");
+        if (ArgOpType != ConstVal->getType())
+          GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
+
+        Call->setArgOperand(Idx, GV);
+        Changed = true;
+      }
 
       // Add the changed CallInst to Solver Worklist
-      Solver.visitCall(*Call);
+      if (Changed)
+        Solver.visitCall(*Call);
     }
   }
 }

diff  --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
index 0ad3586a98025..82020edec442f 100644
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
@@ -2,50 +2,58 @@
 ; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=3 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS3
 ; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=4 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS4
 
-@Global = internal constant i32 1, align 4
+@low = internal constant i32 0, align 4
+@high = internal constant i32 6, align 4
 
-define internal void @recursiveFunc(i32* nocapture readonly %arg) {
-  %temp = alloca i32, align 4
-  %arg.load = load i32, i32* %arg, align 4
-  %arg.cmp = icmp slt i32 %arg.load, 4
-  br i1 %arg.cmp, label %block6, label %ret.block
+define internal void @recursiveFunc(i32* nocapture readonly %lo, i32 %step, i32* nocapture readonly %hi) {
+  %lo.temp = alloca i32, align 4
+  %hi.temp = alloca i32, align 4
+  %lo.load = load i32, i32* %lo, align 4
+  %hi.load = load i32, i32* %hi, align 4
+  %cmp = icmp ne i32 %lo.load, %hi.load
+  br i1 %cmp, label %block6, label %ret.block
 
 block6:
-  call void @print_val(i32 %arg.load)
-  %arg.add = add nsw i32 %arg.load, 1
-  store i32 %arg.add, i32* %temp, align 4
-  call void @recursiveFunc(i32* nonnull %temp)
+  call void @print_val(i32 %lo.load, i32 %hi.load)
+  %add = add nsw i32 %lo.load, %step
+  %sub = sub nsw i32 %hi.load, %step
+  store i32 %add, i32* %lo.temp, align 4
+  store i32 %sub, i32* %hi.temp, align 4
+  call void @recursiveFunc(i32* nonnull %lo.temp, i32 %step, i32* nonnull %hi.temp)
   br label %ret.block
 
 ret.block:
   ret void
 }
 
-; ITERS2:  @funcspec.arg.3 = internal constant i32 3
-; ITERS3:  @funcspec.arg.5 = internal constant i32 4
+; ITERS2:  @funcspec.arg.4 = internal constant i32 2
+; ITERS2:  @funcspec.arg.5 = internal constant i32 4
+
+; ITERS3:  @funcspec.arg.7 = internal constant i32 3
+; ITERS3:  @funcspec.arg.8 = internal constant i32 3
 
 define i32 @main() {
 ; ITERS2-LABEL: @main(
-; ITERS2-NEXT:    call void @print_val(i32 1)
-; ITERS2-NEXT:    call void @print_val(i32 2)
-; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.3)
+; ITERS2-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS2-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.4, i32 1, i32* nonnull @funcspec.arg.5)
 ; ITERS2-NEXT:    ret i32 0
 ;
 ; ITERS3-LABEL: @main(
-; ITERS3-NEXT:    call void @print_val(i32 1)
-; ITERS3-NEXT:    call void @print_val(i32 2)
-; ITERS3-NEXT:    call void @print_val(i32 3)
-; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.5)
+; ITERS3-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS3-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS3-NEXT:    call void @print_val(i32 2, i32 4)
+; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.7, i32 1, i32* nonnull @funcspec.arg.8)
 ; ITERS3-NEXT:    ret i32 0
 ;
 ; ITERS4-LABEL: @main(
-; ITERS4-NEXT:    call void @print_val(i32 1)
-; ITERS4-NEXT:    call void @print_val(i32 2)
-; ITERS4-NEXT:    call void @print_val(i32 3)
+; ITERS4-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS4-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS4-NEXT:    call void @print_val(i32 2, i32 4)
 ; ITERS4-NEXT:    ret i32 0
 ;
-  call void @recursiveFunc(i32* nonnull @Global)
+  call void @recursiveFunc(i32* nonnull @low, i32 1, i32* nonnull @high)
   ret i32 0
 }
 
-declare dso_local void @print_val(i32)
+declare dso_local void @print_val(i32, i32)


        


More information about the llvm-commits mailing list