[llvm] Pack out arguments into a struct (PR #119267)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 9 12:46:22 PST 2024


https://github.com/elhewaty updated https://github.com/llvm/llvm-project/pull/119267

>From dd5735feee88c633895c3f39640fef3962c45d5b Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Mon, 9 Dec 2024 22:24:29 +0200
Subject: [PATCH 1/2] [Attributor] Add pre-commit tests

---
 .../Transforms/Attributor/remove_out_args.ll     | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 llvm/test/Transforms/Attributor/remove_out_args.ll

diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll
new file mode 100644
index 00000000000000..bd52bf5d80656c
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/remove_out_args.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=attributor < %s | FileCheck %s
+
+
+; Baseline: @foo writes i32 42 through its pointer argument and returns true.
+define i1 @foo(ptr %dst) {
+; CHECK-LABEL: define noundef i1 @foo(
+; CHECK-SAME: ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    store i32 42, ptr [[DST]], align 4
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  store i32 42, ptr %dst
+  ret i1 1
+}

>From 8e887118f95fa6048e5ebe663de53b4cbc7204fb Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Mon, 9 Dec 2024 22:40:02 +0200
Subject: [PATCH 2/2] [Attributor] Pack out arguments into a struct

---
 llvm/lib/Transforms/IPO/Attributor.cpp | 126 +++++++++++++++++++++++++
 1 file changed, 126 insertions(+)

diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 116f419129a239..1757d65080fa1b 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -2963,6 +2963,132 @@ bool Attributor::shouldSeedAttribute(AbstractAttribute &AA) {
   return Result;
 }
 
+/// Decide whether \p Arg may be folded into the returned struct. The current
+/// criteria are conservative: \p Arg must be known noalias and known
+/// write-only. Both attributes are requested on behalf of \p QueryingAA with
+/// DepClassTy::NONE.
+static bool canBeComapctedInAStruct(const Argument &Arg, Attributor &A,
+                                    const AbstractAttribute &QueryingAA) {
+  const IRPosition Pos = IRPosition::argument(Arg);
+
+  // The noalias query comes first; the memory-behavior attribute is only
+  // looked up once noalias is known.
+  const auto *NoAliasAA =
+      A.getAAFor<AANoAlias>(QueryingAA, Pos, DepClassTy::NONE);
+  if (!(NoAliasAA && NoAliasAA->isKnownNoAlias()))
+    return false;
+
+  const auto *WriteOnlyAA =
+      A.getAAFor<AAMemoryBehavior>(QueryingAA, Pos, DepClassTy::NONE);
+  return WriteOnlyAA && WriteOnlyAA->isKnownWriteOnly();
+}
+
+/// Redirect direct calls of \p OldFunction to \p NewFunction (struct result).
+/// NOTE(review): only operands that are formal arguments of \p OldFunction are
+/// dropped and out values are not stored back at the call site -- confirm.
+static void replaceArgRetWithStructRetCalls(Function &OldFunction,
+                                            Function &NewFunction) {
+  // Snapshot the call sites: erasing a call mutates the use list we walk.
+  SmallVector<CallBase *, 8> CallSites;
+  for (Use &U : OldFunction.uses())
+    if (auto *CB = dyn_cast<CallBase>(U.getUser()); CB && CB->isCallee(&U))
+      CallSites.push_back(CB);
+  for (CallBase *Call : CallSites) {
+    IRBuilder<> Builder(Call);
+    SmallVector<Value *, 8> NewArgs;
+    for (Value *Op : Call->args())
+      if (!isa<Argument>(Op) || cast<Argument>(Op)->getParent() != &OldFunction)
+        NewArgs.push_back(Op);
+    CallInst *NewCall = Builder.CreateCall(&NewFunction, NewArgs);
+    if (!Call->getType()->isVoidTy()) // Element 0 is the old return value.
+      Call->replaceAllUsesWith(Builder.CreateExtractValue(NewCall, 0));
+    Call->eraseFromParent();
+  }
+}
+
+/// Rewrite \p F so that values written through qualifying pointer arguments are
+/// returned in a struct instead: a clone without those arguments is created,
+/// every return is rebuilt as {original return value (if any), out values...},
+/// direct calls are redirected and \p F is erased.
+/// \returns true if \p F was replaced, false if no argument qualified.
+/// NOTE(review): dropped arguments get no entry in VMap, yet the clone and the
+/// loads below still refer to them -- they need backing storage in the clone.
+static bool convertOutArgsToRetStruct(Function &F, Attributor &A,
+                                      AbstractAttribute &QueryingAA) {
+  // Stored type per argument number; null means the argument is kept.
+  SmallVector<Type *, 8> OutTypes(F.arg_size(), nullptr);
+  bool HasOutArg = false;
+  for (Argument &Arg : F.args()) {
+    if (!Arg.getType()->isPointerTy() ||
+        !canBeComapctedInAStruct(Arg, A, QueryingAA))
+      continue;
+    // Only stores *through* the pointer reveal the type of the out value.
+    for (const Use &U : Arg.uses())
+      if (auto *Store = dyn_cast<StoreInst>(U.getUser());
+          Store && U.getOperandNo() == StoreInst::getPointerOperandIndex())
+        OutTypes[Arg.getArgNo()] = Store->getValueOperand()->getType();
+    HasOutArg |= OutTypes[Arg.getArgNo()] != nullptr;
+  }
+
+  // If there are no valid candidates then return false.
+  if (!HasOutArg)
+    return false;
+
+  // Create the new struct return type and the reduced parameter list.
+  Type *OldRetTy = F.getReturnType();
+  SmallVector<Type *, 4> OutStructElements;
+  SmallVector<Type *, 4> NewParamTypes;
+  if (!OldRetTy->isVoidTy())
+    OutStructElements.push_back(OldRetTy);
+  for (Argument &Arg : F.args())
+    if (Type *Ty = OutTypes[Arg.getArgNo()])
+      OutStructElements.push_back(Ty);
+    else
+      NewParamTypes.push_back(Arg.getType());
+
+  auto *ReturnStructType = StructType::create(F.getContext(), OutStructElements,
+                                              (F.getName() + "Out").str());
+  auto *NewFunctionType =
+      FunctionType::get(ReturnStructType, NewParamTypes, F.isVarArg());
+  // Create the clone inside F's module; a parentless function is unusable.
+  auto *NewFunction = Function::Create(NewFunctionType, F.getLinkage(),
+                                       F.getAddressSpace(), "", F.getParent());
+
+  // Map old args to new args.
+  ValueToValueMapTy VMap;
+  auto *NewArgIt = NewFunction->arg_begin();
+  for (Argument &OldArg : F.args())
+    if (!OutTypes[OldArg.getArgNo()])
+      VMap[&OldArg] = &(*NewArgIt++);
+
+  // Clone the old function into the new one.
+  SmallVector<ReturnInst *, 8> Returns;
+  CloneFunctionInto(NewFunction, &F, VMap,
+                    CloneFunctionChangeType::LocalChangesOnly, Returns);
+
+  // Update the return values (make it struct).
+  for (ReturnInst *Ret : Returns) {
+    IRBuilder<> Builder(Ret);
+    // Poison is fine as the seed: every element is overwritten below.
+    Value *StructRetVal = PoisonValue::get(ReturnStructType);
+    unsigned Idx = 0;
+    if (!OldRetTy->isVoidTy())
+      StructRetVal =
+          Builder.CreateInsertValue(StructRetVal, Ret->getReturnValue(), Idx++);
+    // Load each out value, in argument order, and append it to the struct.
+    for (Argument &Arg : F.args())
+      if (Type *Ty = OutTypes[Arg.getArgNo()])
+        StructRetVal = Builder.CreateInsertValue(
+            StructRetVal, Builder.CreateLoad(Ty, VMap[&Arg]), Idx++);
+    Builder.CreateRet(StructRetVal);
+    Ret->eraseFromParent();
+  }
+  replaceArgRetWithStructRetCalls(F, *NewFunction);
+  NewFunction->takeName(&F);
+  F.eraseFromParent();
+  return true;
+}
+
 ChangeStatus Attributor::rewriteFunctionSignatures(
     SmallSetVector<Function *, 8> &ModifiedFns) {
   ChangeStatus Changed = ChangeStatus::UNCHANGED;



More information about the llvm-commits mailing list