[llvm] Pack out arguments into a struct (PR #119267)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 9 12:42:36 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: None (elhewaty)
<details>
<summary>Changes</summary>
- **[Attributor] Add pre-commit tests**
- **[Attributor] Pack out arguments into a struct**
---
Full diff: https://github.com/llvm/llvm-project/pull/119267.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/IPO/Attributor.cpp (+113)
- (added) llvm/test/Transforms/Attributor/remove_out_args.ll (+16)
``````````diff
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 116f419129a239..90eecf45892ee6 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -2963,6 +2963,119 @@ bool Attributor::shouldSeedAttribute(AbstractAttribute &AA) {
return Result;
}
+/// Returns true if the attributes deduced for \p Arg allow it to be moved out
+/// of the parameter list and into an element of a returned struct.
+///
+/// For now \p Arg qualifies only if it is a pointer that is known to be
+///  - noalias:   no other pointer used by the function accesses the pointee,
+///  - nocapture: the pointer does not escape, so nobody else can read or
+///               write the pointee through a copy of it, and
+///  - writeonly: the function never reads the pointee, so the value handed
+///               back to the caller depends only on what the function stores.
+/// Only "known" (not merely assumed) facts are consulted, so the answer does
+/// not rest on optimistic assumptions that may later be invalidated. This
+/// looks at deduced attributes only; it does not inspect the uses of \p Arg.
+///
+/// NOTE(review): "Comapcted" is a typo for "Compacted"; rename together with
+/// its caller.
+static bool canBeComapctedInAStruct(const Argument &Arg, Attributor &A,
+                                    const AbstractAttribute &QueryingAA) {
+  if (!Arg.getType()->isPointerTy())
+    return false;
+
+  const IRPosition ArgPosition = IRPosition::argument(Arg);
+
+  const auto *NoAliasAA =
+      A.getAAFor<AANoAlias>(QueryingAA, ArgPosition, DepClassTy::NONE);
+  if (!NoAliasAA || !NoAliasAA->isKnownNoAlias())
+    return false;
+
+  // writeonly/noalias alone do not stop the pointer itself from escaping
+  // (e.g. being stored to a global) and the pointee being used elsewhere.
+  const auto *NoCaptureAA =
+      A.getAAFor<AANoCapture>(QueryingAA, ArgPosition, DepClassTy::NONE);
+  if (!NoCaptureAA || !NoCaptureAA->isKnownNoCapture())
+    return false;
+
+  const auto *MemBehaviorAA =
+      A.getAAFor<AAMemoryBehavior>(QueryingAA, ArgPosition, DepClassTy::NONE);
+  return MemBehaviorAA && MemBehaviorAA->isKnownWriteOnly();
+}
+
+static void replaceArgRetWithStructRetCalls(Function &OldFunction, Function &NewFunction) {
+ for (auto UseItr = OldFunction.use_begin(); UseItr != OldFunction.use_end(); ++UseItr) {
+ CallBase *Call = dyn_cast<CallBase>(UseItr->getUser());
+ if (!Call)
+ continue;
+
+ IRBuilder<> Builder(Call);
+ SmallVector<Value *, 8> NewArgs;
+ for (unsigned ArgIdx = 0; ArgIdx < Call->arg_size(); ++ArgIdx)
+ if (std::find_if(OldFunction.arg_begin(), OldFunction.arg_end(),
+ [&](Argument &Arg) { return &Arg == Call->getArgOperand(ArgIdx); }) == OldFunction.arg_end())
+ NewArgs.push_back(Call->getArgOperand(ArgIdx));
+
+ CallInst *NewCall = Builder.CreateCall(&NewFunction, NewArgs);
+ Call->replaceAllUsesWith(NewCall);
+ Call->eraseFromParent();
+ }
+}
+
+static bool convertOutArgsToRetStruct(Function &F, Attributor &A, AbstractAttribute &QueryingAA) {
+ // Get valid ptr args.
+ DenseMap<Argument *, Type *> PtrToType;
+ for (unsigned ArgIdx = 0; ArgIdx < F.arg_size(); ++ArgIdx) {
+ Argument *Arg = F.getArg(ArgIdx);
+ if (Arg->getType()->isPointerTy() && canBeComapctedInAStruct(*Arg, A, QueryingAA)) {
+ // Get the the type of the pointer through its users
+ for (auto UseItr = Arg->use_begin(); UseItr != Arg->use_end(); ++UseItr) {
+ auto *Store = dyn_cast<StoreInst>(UseItr->getUser());
+ if (Store)
+ PtrToType[Arg] = Store->getValueOperand()->getType();
+ }
+ }
+ }
+
+ // If there is no valid candidates then return false.
+ if (PtrToType.empty())
+ return false;
+
+ // Create the new struct return type.
+ SmallVector<Type *, 4> OutStructElements;
+ if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy())
+ OutStructElements.push_back(OriginalFuncTy);
+
+ for (const auto &[Arg, Type] : PtrToType)
+ OutStructElements.push_back(Type);
+
+ auto *ReturnStructType = StructType::create(F.getContext(), OutStructElements, (F.getName() + "Out").str());
+
+ // Get the new Args.
+ SmallVector<Type *, 4> NewParamTypes;
+ for (unsigned ArgIdx = 0; ArgIdx < F.arg_size(); ++ArgIdx)
+ if (!PtrToType.count(F.getArg(ArgIdx)))
+ NewParamTypes.push_back(F.getArg(ArgIdx)->getType());
+
+ auto *NewFunctionType = FunctionType::get(ReturnStructType, NewParamTypes, F.isVarArg());
+ auto *NewFunction = Function::Create(NewFunctionType, F.getLinkage(), F.getAddressSpace(), F.getName());
+
+ // Map old args to new args.
+ ValueToValueMapTy VMap;
+ auto *NewArgIt = NewFunction->arg_begin();
+ for (Argument &OldArg : F.args())
+ if (!PtrToType.count(F.getArg(OldArg.getArgNo())))
+ VMap[&OldArg] = &(*NewArgIt++);
+
+ // Clone the old function into the new one.
+ SmallVector<ReturnInst *, 8> Returns;
+ CloneFunctionInto(NewFunction, &F, VMap, CloneFunctionChangeType::LocalChangesOnly, Returns);
+
+ // Update the return values (make it struct).
+ for (ReturnInst *Ret : Returns) {
+ IRBuilder<> Builder(Ret);
+ SmallVector<Value *, 4> StructValues;
+ // Include original return type, if any
+ if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy())
+ StructValues.push_back(Ret->getReturnValue());
+
+ // Create a load instruction to fill the struct element.
+ for (const auto &[Arg, Ty] : PtrToType) {
+ Value *OutVal = Builder.CreateLoad(Ty, VMap[Arg]);
+ StructValues.push_back(OutVal);
+ }
+
+ // Build the return struct incrementally.
+ Value *StructRetVal = UndefValue::get(ReturnStructType);
+ for (unsigned i = 0; i < StructValues.size(); ++i)
+ StructRetVal = Builder.CreateInsertValue(StructRetVal, StructValues[i], i);
+
+ Builder.CreateRet(StructRetVal);
+ Ret->eraseFromParent();
+ }
+
+ replaceArgRetWithStructRetCalls(F, *NewFunction);
+ F.eraseFromParent();
+}
+
ChangeStatus Attributor::rewriteFunctionSignatures(
SmallSetVector<Function *, 8> &ModifiedFns) {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll
new file mode 100644
index 00000000000000..bd52bf5d80656c
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/remove_out_args.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=attributor < %s | FileCheck %s
+
+
+
+; Pre-commit test for packing "out" pointer arguments into a returned struct.
+; %dst is only written through (a single store in the entry block), never read.
+; The CHECK lines record the current output, in which %dst is still a parameter.
+; NOTE(review): @foo has external linkage and no callers in this module, so its
+; signature cannot be changed without breaking unseen callers; a test that is
+; expected to be transformed likely needs an internal function with a caller.
+define i1 @foo(ptr %dst) {
+; CHECK-LABEL: define noundef i1 @foo(
+; CHECK-SAME: ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: store i32 42, ptr [[DST]], align 4
+; CHECK-NEXT: ret i1 true
+;
+entry:
+ store i32 42, ptr %dst
+ ret i1 1
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/119267
More information about the llvm-commits
mailing list