[llvm] 7ea597e - [FuncSpec] Consider constant struct arguments when specializing.

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 17 10:40:10 PDT 2023


Author: Alexandros Lamprineas
Date: 2023-04-17T18:39:21+01:00
New Revision: 7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f

URL: https://github.com/llvm/llvm-project/commit/7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f
DIFF: https://github.com/llvm/llvm-project/commit/7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f.diff

LOG: [FuncSpec] Consider constant struct arguments when specializing.

Like integer and floating point arguments, constant struct arguments are only considered when specialization on literal constants is explicitly enabled.

Differential Revision: https://reviews.llvm.org/D145374

Added: 
    llvm/test/Transforms/FunctionSpecialization/constant-struct.ll

Modified: 
    llvm/include/llvm/Transforms/Utils/SCCPSolver.h
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/lib/Transforms/Utils/SCCPSolver.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 633052f1c15e7..3512f0a867b91 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -168,17 +168,18 @@ class SCCPSolver {
   /// range with a single element.
   Constant *getConstant(const ValueLatticeElement &LV) const;
 
+  /// Return either a Constant or nullptr for a given Value.
+  Constant *getConstantOrNull(Value *V) const;
+
   /// Return a reference to the set of argument tracked functions.
   SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions();
 
-  /// Mark the constant arguments of a new function specialization. \p F points
-  /// to the cloned function and \p Args contains a list of constant arguments
-  /// represented as pairs of {formal,actual} values (the formal argument is
-  /// associated with the original function definition). All other arguments of
-  /// the specialization inherit the lattice state of their corresponding values
-  /// in the original function.
-  void markArgInFuncSpecialization(Function *F,
-                                   const SmallVectorImpl<ArgInfo> &Args);
+  /// Set the Lattice Value for the arguments of a specialization \p F.
+  /// If an argument is Constant then its lattice value is marked with the
+  /// corresponding actual argument in \p Args. Otherwise, its lattice value
+  /// is inherited (copied) from the corresponding formal argument in \p Args.
+  void setLatticeValueForSpecializationArguments(Function *F,
+                                       const SmallVectorImpl<ArgInfo> &Args);
 
   /// Mark all of the blocks in function \p F non-executable. Clients can used
   /// this method to erase a function from the module (e.g., if it has been

diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 8f5878feab60f..4fc609e4405e4 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -536,7 +536,7 @@ Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &
   // Initialize the lattice state of the arguments of the function clone,
   // marking the argument on which we specialized the function constant
   // with the given value.
-  Solver.markArgInFuncSpecialization(Clone, S.Args);
+  Solver.setLatticeValueForSpecializationArguments(Clone, S.Args);
 
   Solver.addArgumentTrackedFunction(Clone);
   Solver.markBlockExecutable(&Clone->front());
@@ -666,16 +666,9 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
   if (A->user_empty())
     return false;
 
-  // For now, don't attempt to specialize functions based on the values of
-  // composite types.
-  Type *ArgTy = A->getType();
-  if (!ArgTy->isSingleValueType())
-    return false;
-
-  // Specialization of integer and floating point types needs to be explicitly
-  // enabled.
-  if (!SpecializeLiteralConstant &&
-      (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy()))
+  Type *Ty = A->getType();
+  if (!Ty->isPointerTy() && (!SpecializeLiteralConstant ||
+      (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isStructTy())))
     return false;
 
   // SCCP solver does not record an argument that will be constructed on
@@ -686,21 +679,22 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
   // Check the lattice value and decide if we should attemt to specialize,
   // based on this argument. No point in specialization, if the lattice value
   // is already a constant.
-  const ValueLatticeElement &LV = Solver.getLatticeValueFor(A);
-  if (LV.isUnknownOrUndef() || LV.isConstant() ||
-      (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) {
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter "
-                      << A->getNameOrAsOperand() << " is already constant\n");
-    return false;
-  }
-
-  LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
-                    << A->getNameOrAsOperand() << "\n");
+  bool IsOverdefined = Ty->isStructTy()
+    ? any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined)
+    : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A));
 
-  return true;
+  LLVM_DEBUG(
+    if (IsOverdefined)
+      dbgs() << "FnSpecialization: Found interesting parameter "
+             << A->getNameOrAsOperand() << "\n";
+    else
+      dbgs() << "FnSpecialization: Nothing to do, parameter "
+             << A->getNameOrAsOperand() << " is already constant\n";
+  );
+  return IsOverdefined;
 }
 
-/// Check if the valuy \p V  (an actual argument) is a constant or can only
+/// Check if the value \p V  (an actual argument) is a constant or can only
 /// have a constant value. Return that constant.
 Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
   if (isa<PoisonValue>(V))
@@ -720,18 +714,8 @@ Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
   // Select for possible specialisation values that are constants or
   // are deduced to be constants or constant ranges with a single element.
   Constant *C = dyn_cast<Constant>(V);
-  if (!C) {
-    const ValueLatticeElement &LV = Solver.getLatticeValueFor(V);
-    if (LV.isConstant())
-      C = LV.getConstant();
-    else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) {
-      assert(V->getType()->isIntegerTy() && "Non-integral constant range");
-      C = Constant::getIntegerValue(V->getType(),
-                                    *LV.getConstantRange().getSingleElement());
-    } else
-      return nullptr;
-  }
-
+  if (!C)
+    C = Solver.getConstantOrNull(V);
   return C;
 }
 

diff  --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 3dd155ba3c9b2..cc86ab26471a6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -73,30 +73,9 @@ static bool canRemoveInstruction(Instruction *I) {
 }
 
 bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
-  Constant *Const = nullptr;
-  if (V->getType()->isStructTy()) {
-    std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V);
-    if (llvm::any_of(IVs, isOverdefined))
-      return false;
-    std::vector<Constant *> ConstVals;
-    auto *ST = cast<StructType>(V->getType());
-    for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
-      ValueLatticeElement V = IVs[i];
-      ConstVals.push_back(SCCPSolver::isConstant(V)
-                              ? getConstant(V)
-                              : UndefValue::get(ST->getElementType(i)));
-    }
-    Const = ConstantStruct::get(ST, ConstVals);
-  } else {
-    const ValueLatticeElement &IV = getLatticeValueFor(V);
-    if (isOverdefined(IV))
-      return false;
-
-    Const = SCCPSolver::isConstant(IV) ? getConstant(IV)
-                                       : UndefValue::get(V->getType());
-  }
-  assert(Const && "Constant is nullptr here!");
-
+  Constant *Const = getConstantOrNull(V);
+  if (!Const)
+    return false;
   // Replacing `musttail` instructions with constant breaks `musttail` invariant
   // unless the call itself can be removed.
   // Calls with "clang.arc.attachedcall" implicitly use the return value and
@@ -734,12 +713,14 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
 
   Constant *getConstant(const ValueLatticeElement &LV) const;
 
+  Constant *getConstantOrNull(Value *V) const;
+
   SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() {
     return TrackingIncomingArguments;
   }
 
-  void markArgInFuncSpecialization(Function *F,
-                                   const SmallVectorImpl<ArgInfo> &Args);
+  void setLatticeValueForSpecializationArguments(Function *F,
+                                       const SmallVectorImpl<ArgInfo> &Args);
 
   void markFunctionUnreachable(Function *F) {
     for (auto &BB : *F)
@@ -833,36 +814,68 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
   return nullptr;
 }
 
-void SCCPInstVisitor::markArgInFuncSpecialization(
-    Function *F, const SmallVectorImpl<ArgInfo> &Args) {
+Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
+  Constant *Const = nullptr;
+  if (V->getType()->isStructTy()) {
+    std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V);
+    if (any_of(LVs, SCCPSolver::isOverdefined))
+      return nullptr;
+    std::vector<Constant *> ConstVals;
+    auto *ST = cast<StructType>(V->getType());
+    for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
+      ValueLatticeElement LV = LVs[I];
+      ConstVals.push_back(SCCPSolver::isConstant(LV)
+                              ? getConstant(LV)
+                              : UndefValue::get(ST->getElementType(I)));
+    }
+    Const = ConstantStruct::get(ST, ConstVals);
+  } else {
+    const ValueLatticeElement &LV = getLatticeValueFor(V);
+    if (SCCPSolver::isOverdefined(LV))
+      return nullptr;
+    Const = SCCPSolver::isConstant(LV) ? getConstant(LV)
+                                       : UndefValue::get(V->getType());
+  }
+  assert(Const && "Constant is nullptr here!");
+  return Const;
+}
+
+void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F,
+                                        const SmallVectorImpl<ArgInfo> &Args) {
   assert(!Args.empty() && "Specialization without arguments");
   assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() &&
          "Functions should have the same number of arguments");
 
   auto Iter = Args.begin();
-  Argument *NewArg = F->arg_begin();
-  Argument *OldArg = Args[0].Formal->getParent()->arg_begin();
+  Function::arg_iterator NewArg = F->arg_begin();
+  Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin();
   for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) {
 
     LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
                       << NewArg->getNameOrAsOperand() << "\n");
 
-    if (Iter != Args.end() && OldArg == Iter->Formal) {
-      // Mark the argument constants in the new function.
-      markConstant(NewArg, Iter->Actual);
+    // Mark the argument constants in the new function
+    // or copy the lattice state over from the old function.
+    if (Iter != Args.end() && Iter->Formal == &*OldArg) {
+      if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+        for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+          ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}];
+          NewValue.markConstant(Iter->Actual->getAggregateElement(I));
+        }
+      } else {
+        ValueState[&*NewArg].markConstant(Iter->Actual);
+      }
       ++Iter;
-    } else if (ValueState.count(OldArg)) {
-      // For the remaining arguments in the new function, copy the lattice state
-      // over from the old function.
-      //
-      // Note: This previously looked like this:
-      // ValueState[NewArg] = ValueState[OldArg];
-      // This is incorrect because the DenseMap class may resize the underlying
-      // memory when inserting `NewArg`, which will invalidate the reference to
-      // `OldArg`. Instead, we make sure `NewArg` exists before setting it.
-      auto &NewValue = ValueState[NewArg];
-      NewValue = ValueState[OldArg];
-      pushToWorkList(NewValue, NewArg);
+    } else {
+      if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+        for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+          ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}];
+          NewValue = StructValueState[{&*OldArg, I}];
+        }
+      } else {
+        ValueLatticeElement &NewValue = ValueState[&*NewArg];
+        NewValue = ValueState[&*OldArg];
+      }
     }
   }
 }
@@ -1945,13 +1958,17 @@ Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const {
   return Visitor->getConstant(LV);
 }
 
+Constant *SCCPSolver::getConstantOrNull(Value *V) const {
+  return Visitor->getConstantOrNull(V);
+}
+
 SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
   return Visitor->getArgumentTrackedFunctions();
 }
 
-void SCCPSolver::markArgInFuncSpecialization(
-    Function *F, const SmallVectorImpl<ArgInfo> &Args) {
-  Visitor->markArgInFuncSpecialization(F, Args);
+void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F,
+                                   const SmallVectorImpl<ArgInfo> &Args) {
+  Visitor->setLatticeValueForSpecializationArguments(F, Args);
 }
 
 void SCCPSolver::markFunctionUnreachable(Function *F) {

diff  --git a/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
new file mode 100644
index 0000000000000..6c3bfaef49b0a
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
@@ -0,0 +1,46 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+
+; RUN: opt -passes="ipsccp<func-spec>" -force-specialization \
+; RUN:     -funcspec-for-literal-constant -S < %s | FileCheck %s
+
+define i32 @foo(i32 %y0, i32 %y1) {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[Y:%.*]] = insertvalue { i32, i32 } undef, i32 [[Y0:%.*]], 0
+; CHECK-NEXT:    [[YY:%.*]] = insertvalue { i32, i32 } [[Y]], i32 [[Y1:%.*]], 1
+; CHECK-NEXT:    [[CALL:%.*]] = tail call i32 @add.1({ i32, i32 } { i32 2, i32 3 }, { i32, i32 } [[YY]])
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+entry:
+  %y = insertvalue { i32, i32 } undef, i32 %y0, 0
+  %yy = insertvalue { i32, i32 } %y, i32 %y1, 1
+  %call = tail call i32 @add({i32, i32} {i32 2, i32 3}, {i32, i32} %yy)
+  ret i32 %call
+}
+
+define i32 @bar(i32 %x0, i32 %x1) {
+; CHECK-LABEL: @bar(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[X:%.*]] = insertvalue { i32, i32 } undef, i32 [[X0:%.*]], 0
+; CHECK-NEXT:    [[XX:%.*]] = insertvalue { i32, i32 } [[X]], i32 [[X1:%.*]], 1
+; CHECK-NEXT:    [[CALL:%.*]] = tail call i32 @add.2({ i32, i32 } [[XX]], { i32, i32 } { i32 3, i32 2 })
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+entry:
+  %x = insertvalue { i32, i32 } undef, i32 %x0, 0
+  %xx = insertvalue { i32, i32 } %x, i32 %x1, 1
+  %call = tail call i32 @add({i32, i32} %xx, {i32, i32} {i32 3, i32 2})
+  ret i32 %call
+}
+
+define internal i32 @add({i32, i32} %x, {i32, i32} %y) {
+entry:
+  %x0 = extractvalue {i32, i32} %x, 0
+  %y0 = extractvalue {i32, i32} %y, 0
+  %add0 = add nsw i32 %x0, %y0
+  %x1 = extractvalue {i32, i32} %x, 1
+  %y1 = extractvalue {i32, i32} %y, 1
+  %add1 = add nsw i32 %x1, %y1
+  %mul = mul i32 %add0, %add1
+  ret i32 %mul
+}


        


More information about the llvm-commits mailing list