[llvm] [Attributor] Pack out arguments into a struct (PR #119267)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 26 05:02:04 PST 2024


================
@@ -12989,6 +12991,179 @@ struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
 };
 } // namespace
 
+/// ----------- AAConvertOutArgument ----------
+namespace {
+struct AAConvertOutArgumentFunction final : AAConvertOutArgument {
+  AAConvertOutArgumentFunction(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ///
+  /// The function is considered convertible when it has a body and at least
+  /// one pointer argument accepted by isEligibleArgument(). Either outcome is
+  /// final: the state is fixed immediately via indicate*Fixpoint().
+  ChangeStatus updateImpl(Attributor &A) override {
+    const Function *F = getAssociatedFunction();
+    if (!F || F->isDeclaration())
+      return indicatePessimisticFixpoint();
+
+    bool HasCandidateArg = llvm::any_of(F->args(), [&](const Argument &Arg) {
+      return Arg.getType()->isPointerTy() && isEligibleArgument(Arg, A, *this);
+    });
+
+    return HasCandidateArg ? indicateOptimisticFixpoint()
+                           : indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ///
+  /// Creates "<name>.converted", a clone of the associated function in which
+  /// every argument accepted by isEligibleArgument() is removed from the
+  /// parameter list. Inside the clone each removed pointer is replaced by a
+  /// local stack slot, and every return yields a struct made of the original
+  /// return value (if non-void) followed by the final contents of each slot,
+  /// in argument order. The original function is left untouched; call sites
+  /// are redirected separately, looking the clone up by name.
+  ChangeStatus manifest(Attributor &A) override {
+    Function *F = getAssociatedFunction();
+    if (!F || F->isDeclaration())
+      return ChangeStatus::UNCHANGED;
+
+    // Find the type stored through each eligible argument. Call sites drop
+    // exactly the arguments for which isEligibleArgument() holds, so all of
+    // them must be converted. If one has no store through it, or stores of
+    // differing types, its returned type is unknown: bail out before creating
+    // anything, so no ".converted" function exists for call sites to use.
+    SmallVector<Argument *, 4> OutArgs;
+    DenseMap<const Argument *, Type *> OutArgTypes;
+    for (Argument &Arg : F->args()) {
+      if (!isEligibleArgument(Arg, A, *this))
+        continue;
+
+      Type *StoredTy = nullptr;
+      for (const User *U : Arg.users()) {
+        const auto *Store = dyn_cast<StoreInst>(U);
+        // Only stores *through* the argument define an out value; storing
+        // the pointer itself somewhere is not one.
+        if (!Store || Store->getPointerOperand() != &Arg)
+          continue;
+        Type *Ty = Store->getValueOperand()->getType();
+        if (StoredTy && StoredTy != Ty)
+          return ChangeStatus::UNCHANGED;
+        StoredTy = Ty;
+      }
+      if (!StoredTy)
+        return ChangeStatus::UNCHANGED;
+
+      OutArgs.push_back(&Arg);
+      OutArgTypes[&Arg] = StoredTy;
+    }
+
+    // No candidates: nothing to rewrite.
+    if (OutArgs.empty())
+      return ChangeStatus::UNCHANGED;
+
+    // New return type: original return value (if any), then one element per
+    // out argument.
+    Type *OrigRetTy = F->getReturnType();
+    const bool HasOrigRet = !OrigRetTy->isVoidTy();
+    SmallVector<Type *, 4> OutStructElementsTypes;
+    if (HasOrigRet)
+      OutStructElementsTypes.push_back(OrigRetTy);
+    for (const Argument *Arg : OutArgs)
+      OutStructElementsTypes.push_back(OutArgTypes.lookup(Arg));
+
+    auto *ReturnStructType =
+        StructType::create(F->getContext(), OutStructElementsTypes,
+                           (F->getName() + "Out").str());
+
+    // Parameters of the clone: every argument that is not an out argument.
+    SmallVector<Type *, 4> NewParamTypes;
+    for (const Argument &Arg : F->args())
+      if (!OutArgTypes.count(&Arg))
+        NewParamTypes.push_back(Arg.getType());
+
+    auto *NewFunctionType =
+        FunctionType::get(ReturnStructType, NewParamTypes, F->isVarArg());
+    // Pass the module so the clone is actually inserted into it; call sites
+    // find it with Module::getFunction().
+    Function *NewFunction = Function::Create(
+        NewFunctionType, F->getLinkage(), F->getAddressSpace(),
+        F->getName() + ".converted", F->getParent());
+
+    // CloneFunctionInto expects every argument of F to be mapped. Kept
+    // arguments map to their counterparts in the clone.
+    ValueToValueMapTy VMap;
+    Argument *NewArg = NewFunction->arg_begin();
+    for (Argument &Arg : F->args()) {
+      if (OutArgTypes.count(&Arg))
+        continue;
+      NewArg->setName(Arg.getName());
+      VMap[&Arg] = NewArg++;
+    }
+
+    // Each out argument maps to a not-yet-inserted alloca (plus an
+    // addrspacecast when the alloca address space differs from the
+    // argument's), placed in the clone's entry block after cloning.
+    const DataLayout &DL = F->getParent()->getDataLayout();
+    SmallVector<AllocaInst *, 4> OutSlots;
+    SmallVector<Instruction *, 8> EntryInsts;
+    for (Argument *Arg : OutArgs) {
+      auto *Slot = new AllocaInst(OutArgTypes.lookup(Arg),
+                                  DL.getAllocaAddrSpace(),
+                                  Arg->getName() + ".out");
+      OutSlots.push_back(Slot);
+      EntryInsts.push_back(Slot);
+
+      Value *Replacement = Slot;
+      if (Slot->getType() != Arg->getType()) {
+        auto *Cast = new AddrSpaceCastInst(Slot, Arg->getType(),
+                                           Arg->getName() + ".out.cast");
+        EntryInsts.push_back(Cast);
+        Replacement = Cast;
+      }
+      VMap[Arg] = Replacement;
+    }
+
+    // Clone the old function into the new one.
+    SmallVector<ReturnInst *, 8> Returns;
+    CloneFunctionInto(NewFunction, F, VMap,
+                      CloneFunctionChangeType::LocalChangesOnly, Returns);
+
+    // Return attributes copied from F describe the old return type, not the
+    // new struct, so drop them.
+    NewFunction->setAttributes(NewFunction->getAttributes().removeRetAttributes(
+        NewFunction->getContext()));
+
+    // Materialize the out-argument slots at the top of the entry block.
+    BasicBlock &Entry = NewFunction->getEntryBlock();
+    BasicBlock::iterator InsertPt = Entry.getFirstInsertionPt();
+    for (Instruction *I : EntryInsts)
+      I->insertInto(&Entry, InsertPt);
+
+    // Rewrite every return to return the packed struct. NOTE(review): a path
+    // that returns without storing to an out argument loads an uninitialized
+    // slot; presumably isEligibleArgument() rules this out — confirm.
+    for (ReturnInst *Ret : Returns) {
+      IRBuilder<> Builder(Ret);
+      Value *StructRetVal = PoisonValue::get(ReturnStructType);
+      unsigned Idx = 0;
+      if (HasOrigRet)
+        StructRetVal = Builder.CreateInsertValue(
+            StructRetVal, Ret->getReturnValue(), Idx++);
+      for (AllocaInst *Slot : OutSlots) {
+        Value *OutVal = Builder.CreateLoad(Slot->getAllocatedType(), Slot);
+        StructRetVal = Builder.CreateInsertValue(StructRetVal, OutVal, Idx++);
+      }
+      Builder.CreateRet(StructRetVal);
+      // The clone was created in this manifest step; remove the old return
+      // right away so each block keeps a single terminator.
+      Ret->eraseFromParent();
+    }
+
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractAttribute::getAsStr(...).
+  const std::string getAsStr(Attributor *A) const override {
+    return "AAConvertOutArgumentFunction";
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+};
+
+struct AAConvertOutArgumentCallSite final : AAConvertOutArgument {
+  AAConvertOutArgumentCallSite(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    CallBase *CB = cast<CallBase>(getCtxI());
+    Function *F = CB->getCalledFunction();
+    if (!F)
+      return indicatePessimisticFixpoint();
+
+    // Get convert attribute.
+    auto *ConvertAA = A.getAAFor<AAConvertOutArgument>(
+        *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+
+    // If function will be transformed, mark this call site for update
+    if (!ConvertAA || ConvertAA->isAssumedConvertible())
+      return ChangeStatus::CHANGED;
+
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    CallBase *CB = cast<CallBase>(getCtxI());
+    Function *F = CB->getCalledFunction();
+    if (!F)
+      return ChangeStatus::UNCHANGED;
+
+    IRBuilder<> Builder(CB);
+    // Create args for new call.
+    SmallVector<Value *, 4> NewArgs;
+    for (unsigned ArgIdx = 0; ArgIdx < CB->arg_size(); ++ArgIdx) {
+      Value *Arg = CB->getArgOperand(ArgIdx);
+      Argument *ParamArg = F->getArg(ArgIdx); 
+      if (!isEligibleArgument(*ParamArg, A, *this))
+        NewArgs.push_back(Arg);
+    }
+
+    Module *M = F->getParent();
+    auto *NewF = M->getFunction((F->getName() + ".converted").str());
+    if (!NewF)
+      return ChangeStatus::UNCHANGED;
+
+    FunctionCallee NewCallee(NewF->getFunctionType(), NewF);
+    Instruction *NewCall = CallInst::Create(NewCallee, NewArgs, CB->getName() + ".converted", CB);
+    IRPosition ReturnPos = IRPosition::callsite_returned(*CB);
+    A.changeAfterManifest(ReturnPos, *NewCall);
+
+    // Redirect all uses of the old call to the new call.
+    for (auto &Use : CB->uses())
+      Use.set(NewCall);
----------------
arsenm wrote:

Should assert that this is the callee use, and not a data operand or other type of non-call user 

https://github.com/llvm/llvm-project/pull/119267


More information about the llvm-commits mailing list