[llvm] [Attributor] Pack out arguments into a struct (PR #119267)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 26 10:05:41 PST 2025


https://github.com/elhewaty updated https://github.com/llvm/llvm-project/pull/119267

>From 6b545c2ffe586082156fe53a8a6c5c010f80cfc1 Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Mon, 9 Dec 2024 22:24:29 +0200
Subject: [PATCH 1/2] [Attributor] Add pre-commit tests

---
 .../Transforms/Attributor/remove_out_args.ll  | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 llvm/test/Transforms/Attributor/remove_out_args.ll

diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll
new file mode 100644
index 00000000000000..40c39ea41ff67b
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/remove_out_args.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=attributor < %s | FileCheck %s
+
+; %dst is only written (never read) by @foo, so it is an out-argument candidate.
+define internal i1 @foo(ptr %dst) {
+entry:
+  store i32 42, ptr %dst
+  ret i1 true
+}
+
+; FIXME: Regenerate the full assertions with update_test_checks.py.
+define i1 @fee(i32 %x, i32 %y) {
+; CHECK-LABEL: define {{.*}}i1 @fee(
+  %ptr = alloca i32
+  %a = call i1 @foo(ptr %ptr)
+  %b = load i32, ptr %ptr
+  %c = icmp sle i32 %b, %x
+  %xor = xor i1 %a, %c
+  ret i1 %xor
+}

>From e4be7f00ced90b5f85630b97a1b9ae475c2381ca Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Wed, 1 Jan 2025 18:38:56 +0200
Subject: [PATCH 2/2] [Attributor] Convert out arguments into a struct return

---
 llvm/include/llvm/Transforms/IPO/Attributor.h |  33 +++
 llvm/lib/Transforms/IPO/Attributor.cpp        |   7 +
 .../Transforms/IPO/AttributorAttributes.cpp   | 223 ++++++++++++++++++
 3 files changed, 263 insertions(+)

diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 85893146997499..19a319a57d3268 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -6477,6 +6477,39 @@ struct AADenormalFPMath
   static const char ID;
 };
 
+/// Abstract attribute that turns "out" pointer arguments into struct returns.
+struct AAConvertOutArgument
+    : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  AAConvertOutArgument(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAConvertOutArgument &createForPosition(const IRPosition &IRP,
+                                                 Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAConvertOutArgument"; }
+
+  /// Return true if the out arguments are assumed to be convertible.
+  bool isAssumedConvertible() const { return getAssumed(); }
+
+  /// Return true if the out arguments are known to be convertible.
+  bool isKnownConvertible() const { return getKnown(); }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is
+  /// AAConvertOutArgument.
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+};
+
 raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);
 
 /// Run options, used by the pass manager.
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index a93284926d684f..373ce71afe3582 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -3459,6 +3459,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
     }
   }
 
+  bool markedAsAAConvertArgument = false;
   for (Argument &Arg : F.args()) {
     IRPosition ArgPos = IRPosition::argument(Arg);
     auto ArgNo = Arg.getArgNo();
@@ -3510,6 +3511,12 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
       // Every argument with pointer type might be privatizable (or
       // promotable)
       getOrCreateAAFor<AAPrivatizablePtr>(ArgPos);
+
+      // Every function with pointer argument type can have out arguments.
+      if (!markedAsAAConvertArgument) {
+        getOrCreateAAFor<AAConvertOutArgument>(FPos);
+        markedAsAAConvertArgument = true;
+      }
     } else if (AttributeFuncs::isNoFPClassCompatibleType(Arg.getType())) {
       getOrCreateAAFor<AANoFPClass>(ArgPos);
     }
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 58b8f1f779f729..ccce5ac7c418a4 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -68,6 +68,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <cassert>
@@ -197,6 +198,7 @@ PIPE_OPERATOR(AAAllocationInfo)
 PIPE_OPERATOR(AAIndirectCallInfo)
 PIPE_OPERATOR(AAGlobalValueInfo)
 PIPE_OPERATOR(AADenormalFPMath)
+PIPE_OPERATOR(AAConvertOutArgument)
 
 #undef PIPE_OPERATOR
 
@@ -12987,6 +12989,346 @@ struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
 };
 } // namespace
 
+/// ----------- AAConvertOutArgument ----------
+namespace {
+/// Facts about an argument that is only ever written through plain stores.
+struct OutArgInfo {
+  /// Type of the stored value; it becomes a field of the new struct return.
+  Type *Ty = nullptr;
+  /// Smallest alignment among the original stores. The caller-side store
+  /// that writes the returned value back must not assume more than this.
+  Align Alignment;
+};
+
+/// Return the stored type and alignment if \p Arg is a pure out argument:
+/// every use is a simple (non-volatile, non-atomic) store *through* the
+/// pointer, all stores write the same type, and at least one of them is in
+/// the entry block. The entry-block store guarantees the value is written
+/// before every return, so the alloca replacing the argument is never read
+/// uninitialized. Any other use (load, GEP, escaping as a stored value or a
+/// call operand) makes the argument ineligible.
+std::optional<OutArgInfo> analyzeOutArgument(const Argument &Arg) {
+  if (!Arg.getType()->isPointerTy() || Arg.getParent()->isDeclaration())
+    return std::nullopt;
+
+  const BasicBlock &EntryBB = Arg.getParent()->getEntryBlock();
+  OutArgInfo Info;
+  bool StoredInEntry = false;
+  for (const Use &U : Arg.uses()) {
+    const auto *SI = dyn_cast<StoreInst>(U.getUser());
+    // The argument must be the address written to, not the value stored.
+    if (!SI || !SI->isSimple() ||
+        U.getOperandNo() != StoreInst::getPointerOperandIndex())
+      return std::nullopt;
+
+    Type *StoredTy = SI->getValueOperand()->getType();
+    if (isa<ScalableVectorType>(StoredTy) || (Info.Ty && Info.Ty != StoredTy))
+      return std::nullopt;
+
+    Info.Alignment =
+        Info.Ty ? std::min(Info.Alignment, SI->getAlign()) : SI->getAlign();
+    Info.Ty = StoredTy;
+    StoredInEntry |= SI->getParent() == &EntryBB;
+  }
+
+  // An argument without any use also ends up here, with StoredInEntry false.
+  if (!StoredInEntry)
+    return std::nullopt;
+  return Info;
+}
+
+/// Return true if \p Arg is assumed to be an out argument that \p AA may turn
+/// into a field of the function's return value. Write-only is implied by
+/// analyzeOutArgument, which accepts nothing but stores through \p Arg.
+bool isEligibleArgument(const Argument &Arg, Attributor &A,
+                        const AbstractAttribute &AA) {
+  // Leave ABI-significant pointee-in-memory arguments (byval, sret, ...) alone.
+  if (Arg.hasPointeeInMemoryValueAttr() || !analyzeOutArgument(Arg))
+    return false;
+
+  // With noalias no other pointer can observe the memory while the callee
+  // runs, so performing the write right after the call is not observable.
+  bool IsKnownNoAlias = false;
+  return AA::hasAssumedIRAttr<Attribute::NoAlias>(
+      A, &AA, IRPosition::argument(Arg), DepClassTy::OPTIONAL, IsKnownNoAlias);
+}
+
+/// Converts the out arguments of a function into fields of a struct return
+/// value: an internal clone without those parameters is created and the
+/// eligible direct calls are redirected to it; the original stays intact.
+/// TODO: Consider Attributor::registerFunctionSignatureRewrite, the
+/// Attributor's dedicated mechanism for changing function signatures.
+struct AAConvertOutArgumentFunction final : AAConvertOutArgument {
+  AAConvertOutArgumentFunction(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// ArgumentsStates[I] is true while argument I is assumed convertible.
+  /// Entries only ever change from true to false.
+  SmallVector<bool> ArgumentsStates;
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    const Function *F = getAssociatedFunction();
+    // The body is needed for cloning, and a new function is only created
+    // when the Attributor runs on the whole module.
+    if (!F || F->isDeclaration() || !A.isModulePass()) {
+      indicatePessimisticFixpoint();
+      return;
+    }
+
+    // A musttail call must keep the caller's return type, which this
+    // transformation changes.
+    for (const BasicBlock &BB : *F)
+      for (const Instruction &I : BB)
+        if (const auto *CI = dyn_cast<CallInst>(&I);
+            CI && CI->isMustTailCall()) {
+          indicatePessimisticFixpoint();
+          return;
+        }
+
+    // Optimistically assume that every argument is convertible.
+    ArgumentsStates.assign(F->arg_size(), true);
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    const Function *F = getAssociatedFunction();
+    if (!F || F->isDeclaration())
+      return indicatePessimisticFixpoint();
+
+    // If the callee could unwind after writing an out argument, the caller
+    // would observe that write originally but not after the rewrite.
+    bool IsKnownNoUnwind = false;
+    if (!AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+            A, this, getIRPosition(), DepClassTy::REQUIRED, IsKnownNoUnwind))
+      return indicatePessimisticFixpoint();
+
+    bool Changed = false;
+    bool AnyConvertible = false;
+    for (const Argument &Arg : F->args()) {
+      bool &State = ArgumentsStates[Arg.getArgNo()];
+      if (State && !isEligibleArgument(Arg, A, *this)) {
+        State = false;
+        Changed = true;
+      }
+      AnyConvertible |= State;
+    }
+
+    // Without a single candidate there is nothing to convert.
+    if (!AnyConvertible)
+      return indicatePessimisticFixpoint();
+    return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    Function &F = *getAssociatedFunction();
+
+    // Collect the out arguments in argument order.
+    SmallVector<Argument *, 4> OutArgs;
+    SmallVector<OutArgInfo, 4> OutInfos;
+    for (Argument &Arg : F.args()) {
+      if (!ArgumentsStates[Arg.getArgNo()])
+        continue;
+      if (std::optional<OutArgInfo> Info = analyzeOutArgument(Arg)) {
+        OutArgs.push_back(&Arg);
+        OutInfos.push_back(*Info);
+      }
+    }
+    if (OutArgs.empty())
+      return ChangeStatus::UNCHANGED;
+
+    // Only direct calls with a matching signature can be redirected; every
+    // other use keeps referring to the original function, which stays valid.
+    SmallVector<CallInst *, 8> Calls;
+    for (Use &U : F.uses())
+      if (auto *CI = dyn_cast<CallInst>(U.getUser());
+          CI && CI->isCallee(&U) && !CI->isMustTailCall() &&
+          CI->getFunctionType() == F.getFunctionType() &&
+          A.isRunOn(*CI->getFunction()))
+        Calls.push_back(CI);
+    if (Calls.empty())
+      return ChangeStatus::UNCHANGED;
+
+    Function *NewF = createConvertedFunction(F, OutArgs, OutInfos);
+    for (CallInst *CI : Calls)
+      rewriteCallSite(A, *CI, *NewF, OutArgs, OutInfos);
+    return ChangeStatus::CHANGED;
+  }
+
+  /// Create an internal clone of \p F without the parameters in \p OutArgs.
+  /// The clone returns a struct holding the original return value (if any)
+  /// followed by the final value of each out argument.
+  static Function *createConvertedFunction(Function &F,
+                                           ArrayRef<Argument *> OutArgs,
+                                           ArrayRef<OutArgInfo> OutInfos) {
+    LLVMContext &Ctx = F.getContext();
+    Type *OrigRetTy = F.getReturnType();
+    const bool HasOrigRet = !OrigRetTy->isVoidTy();
+
+    SmallVector<Type *, 4> FieldTypes;
+    if (HasOrigRet)
+      FieldTypes.push_back(OrigRetTy);
+    for (const OutArgInfo &Info : OutInfos)
+      FieldTypes.push_back(Info.Ty);
+    StructType *RetTy =
+        StructType::create(Ctx, FieldTypes, (F.getName() + "_out").str());
+
+    SmallVector<Type *, 8> ParamTypes;
+    for (Argument &Arg : F.args())
+      if (!is_contained(OutArgs, &Arg))
+        ParamTypes.push_back(Arg.getType());
+
+    // Insert the clone into the module so call sites can reference it and
+    // IRBuilder can reach the DataLayout; internal linkage keeps it from
+    // becoming a new exported symbol.
+    Function *NewF = Function::Create(
+        FunctionType::get(RetTy, ParamTypes, F.isVarArg()),
+        GlobalValue::InternalLinkage, F.getAddressSpace(),
+        F.getName() + ".converted", F.getParent());
+
+    // CloneFunctionInto needs every old argument mapped up front, so the
+    // allocas replacing the out arguments are built in a scratch block and
+    // moved into the cloned entry block afterwards.
+    ValueToValueMapTy VMap;
+    BasicBlock *ScratchBB = BasicBlock::Create(Ctx, "", NewF);
+    IRBuilder<> ScratchBuilder(ScratchBB);
+    auto NewArgIt = NewF->arg_begin();
+    for (Argument &Arg : F.args()) {
+      const auto *OutIt = llvm::find(OutArgs, &Arg);
+      if (OutIt == OutArgs.end()) {
+        NewArgIt->setName(Arg.getName());
+        VMap[&Arg] = &*NewArgIt++;
+        continue;
+      }
+      Type *OutTy = OutInfos[OutIt - OutArgs.begin()].Ty;
+      LLVM_DEBUG(dbgs() << "[AAConvertOutArgument] " << Arg << " => "
+                        << *OutTy << "\n");
+      VMap[&Arg] =
+          ScratchBuilder.CreateAlloca(OutTy, nullptr, Arg.getName() + ".out");
+    }
+
+    SmallVector<ReturnInst *, 8> Returns;
+    CloneFunctionInto(NewF, &F, VMap, CloneFunctionChangeType::LocalChangesOnly,
+                      Returns);
+    auto *ClonedEntry = cast<BasicBlock>(VMap[&F.getEntryBlock()]);
+    ClonedEntry->splice(ClonedEntry->begin(), ScratchBB);
+    ScratchBB->eraseFromParent();
+
+    // A struct return is incompatible with the original return attributes
+    // and with 'returned' on any remaining parameter.
+    NewF->setAttributes(NewF->getAttributes().removeRetAttributes(Ctx));
+    for (Argument &NewArg : NewF->args())
+      NewArg.removeAttr(Attribute::Returned);
+
+    // Return the original value (if any) followed by each out value.
+    for (ReturnInst *Ret : Returns) {
+      IRBuilder<> Builder(Ret);
+      SmallVector<Value *, 4> Fields;
+      if (HasOrigRet)
+        Fields.push_back(Ret->getReturnValue());
+      for (auto [OutArg, Info] : zip(OutArgs, OutInfos))
+        Fields.push_back(Builder.CreateLoad(Info.Ty, VMap[OutArg]));
+
+      Value *Agg = PoisonValue::get(RetTy);
+      for (unsigned Idx = 0, E = Fields.size(); Idx != E; ++Idx)
+        Agg = Builder.CreateInsertValue(Agg, Fields[Idx], Idx);
+      Builder.CreateRet(Agg);
+      Ret->eraseFromParent();
+    }
+    return NewF;
+  }
+
+  /// Redirect the direct call \p CI of the original function to \p NewF and
+  /// store each returned out value through the pointer the caller passed.
+  static void rewriteCallSite(Attributor &A, CallInst &CI, Function &NewF,
+                              ArrayRef<Argument *> OutArgs,
+                              ArrayRef<OutArgInfo> OutInfos) {
+    SmallVector<Value *, 8> NewArgs;
+    for (unsigned ArgNo = 0, E = CI.arg_size(); ArgNo != E; ++ArgNo) {
+      bool IsOutArg = any_of(OutArgs, [ArgNo](const Argument *OutArg) {
+        return OutArg->getArgNo() == ArgNo;
+      });
+      if (!IsOutArg)
+        NewArgs.push_back(CI.getArgOperand(ArgNo));
+    }
+
+    SmallVector<OperandBundleDef, 1> Bundles;
+    CI.getOperandBundlesAsDefs(Bundles);
+
+    IRBuilder<> Builder(&CI);
+    CallInst *NewCall =
+        Builder.CreateCall(NewF.getFunctionType(), &NewF, NewArgs, Bundles,
+                           CI.getName() + ".converted");
+    NewCall->setCallingConv(CI.getCallingConv());
+
+    // The callee used to perform these writes before returning; perform them
+    // right after the call instead, with no more alignment than it assumed.
+    const bool HasOrigRet = !CI.getType()->isVoidTy();
+    unsigned Field = HasOrigRet ? 1 : 0;
+    for (auto [OutArg, Info] : zip(OutArgs, OutInfos)) {
+      Value *OutVal = Builder.CreateExtractValue(NewCall, Field);
+      Builder.CreateAlignedStore(OutVal, CI.getArgOperand(OutArg->getArgNo()),
+                                 Info.Alignment);
+      ++Field;
+    }
+
+    // The new call returns a struct, so uses of the old result are redirected
+    // to its first field rather than to the call itself.
+    if (HasOrigRet)
+      CI.replaceAllUsesWith(Builder.CreateExtractValue(NewCall, 0));
+    A.deleteAfterManifest(CI);
+  }
+
+  /// See AbstractAttribute::getAsStr(...).
+  const std::string getAsStr(Attributor *A) const override {
+    return "AAConvertOutArgumentFunction";
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+};
+
+/// Call-site view of AAConvertOutArgument: a call site is convertible exactly
+/// when its direct callee is. AAConvertOutArgumentFunction performs the
+/// rewrite of the call itself.
+struct AAConvertOutArgumentCallSite final : AAConvertOutArgument {
+  AAConvertOutArgumentCallSite(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    const auto *CB = cast<CallBase>(getCtxI());
+    const Function *F = CB->getCalledFunction();
+    if (!F)
+      return indicatePessimisticFixpoint();
+
+    // Mirror the callee's state; reporting UNCHANGED while it stays
+    // convertible lets the fixpoint iteration settle.
+    const auto *FnAA = A.getAAFor<AAConvertOutArgument>(
+        *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+    if (!FnAA || !FnAA->isAssumedConvertible())
+      return indicatePessimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    // AAConvertOutArgumentFunction rewrites the callee's direct calls
+    // together with the callee, so there is nothing left to do here.
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::getAsStr(...).
+  const std::string getAsStr(Attributor *A) const override {
+    return "AAConvertOutArgumentCallSite";
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+};
+} // namespace
+
 const char AANoUnwind::ID = 0;
 const char AANoSync::ID = 0;
 const char AANoFree::ID = 0;
@@ -13024,6 +13245,7 @@ const char AAAllocationInfo::ID = 0;
 const char AAIndirectCallInfo::ID = 0;
 const char AAGlobalValueInfo::ID = 0;
 const char AADenormalFPMath::ID = 0;
+const char AAConvertOutArgument::ID = 0;
 
 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
@@ -13139,6 +13361,7 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAConvertOutArgument)
 
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)



More information about the llvm-commits mailing list