[llvm] [Attributor] Pack out arguments into a struct (PR #119267)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 23 02:09:04 PST 2024


https://github.com/elhewaty updated https://github.com/llvm/llvm-project/pull/119267

>From 208b57616c67197830b6e5408a90356932436674 Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Mon, 9 Dec 2024 22:24:29 +0200
Subject: [PATCH 1/2] [Attributor] Add pre-commit tests

---
 .../Transforms/Attributor/remove_out_args.ll     | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 llvm/test/Transforms/Attributor/remove_out_args.ll

diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll
new file mode 100644
index 00000000000000..bd52bf5d80656c
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/remove_out_args.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=attributor < %s | FileCheck %s
+
+; NOTE(review): %dst has no `noalias` attribute (the assertions below show none
+; is deduced), so the out-argument conversion should leave @foo unchanged.
+define i1 @foo(ptr %dst) {
+; CHECK-LABEL: define noundef i1 @foo(
+; CHECK-SAME: ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    store i32 42, ptr [[DST]], align 4
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  store i32 42, ptr %dst
+  ret i1 1
+}

>From 4faf50acbeedf24a41e105a26496521e700cb120 Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Sat, 21 Dec 2024 23:40:41 +0200
Subject: [PATCH 2/2] [Attributor] Add AAConvertOutArgument class to the
 attributor framework

---
 llvm/include/llvm/Transforms/IPO/Attributor.h |  48 +++++
 llvm/lib/Transforms/IPO/Attributor.cpp        |   3 +
 .../Transforms/IPO/AttributorAttributes.cpp   | 177 ++++++++++++++++++
 3 files changed, 228 insertions(+)

diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 8915969f75466c..fd831a314e8090 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -6469,6 +6469,54 @@ struct AADenormalFPMath
   static const char ID;
 };
 
+/// An abstract attribute for converting out arguments into struct elements.
+struct AAConvertOutArgument
+    : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  AAConvertOutArgument(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAConvertOutArgument &createForPosition(const IRPosition &IRP,
+                                                 Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAConvertOutArgument"; }
+
+  /// Return true if convertible is assumed.
+  bool isAssumedConvertible() const { return getAssumed(); }
+
+  /// Return true if convertible is known.
+  bool isKnownConvertible() const { return getKnown(); }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is
+  /// AAConvertOutArgument.
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+
+protected:
+  static bool isEligibleArgumentType(Type *Ty) { return Ty->isPointerTy(); }
+  /// True if \p Arg is a noalias, non-byval-like pointer known write-only.
+  static bool isEligibleArgument(const Argument &Arg, Attributor &A,
+                                 const AbstractAttribute &AA) {
+    if (!isEligibleArgumentType(Arg.getType()))
+      return false;
+    // Check the cheap IR attributes before looking up AAMemoryBehavior.
+    if (!Arg.hasNoAliasAttr() || Arg.hasPointeeInMemoryValueAttr())
+      return false;
+    auto *AAMem = A.getAAFor<AAMemoryBehavior>(
+        AA, IRPosition::argument(Arg), DepClassTy::NONE);
+    return AAMem && AAMem->isKnownWriteOnly();
+  }
+};
+
 raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);
 
 /// Run options, used by the pass manager.
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 116f419129a239..dd2dfda97bf2e6 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -3412,6 +3412,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
     // Every function can track active assumptions.
     getOrCreateAAFor<AAAssumptionInfo>(FPos);
 
+    // Every function can have out arguments.
+    getOrCreateAAFor<AAConvertOutArgument>(FPos);
+
     // If we're not using a dynamic mode for float, there's nothing worthwhile
     // to infer. This misses the edge case denormal-fp-math="dynamic" and
     // denormal-fp-math-f32=something, but that likely has no real world use.
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index eb45df34771d32..ec0d3961389fc7 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -68,6 +68,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <cassert>
@@ -197,6 +198,7 @@ PIPE_OPERATOR(AAAllocationInfo)
 PIPE_OPERATOR(AAIndirectCallInfo)
 PIPE_OPERATOR(AAGlobalValueInfo)
 PIPE_OPERATOR(AADenormalFPMath)
+PIPE_OPERATOR(AAConvertOutArgument)
 
 #undef PIPE_OPERATOR
 
@@ -12989,6 +12991,294 @@ struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
 };
 } // namespace

+/// ----------- AAConvertOutArgument ----------
+namespace {
+/// An out argument that is turned into an element of the struct returned by
+/// the converted clone of its function.
+struct OutArgInfo {
+  /// The pointer argument that is removed from the signature.
+  Argument *Arg = nullptr;
+  /// The type of every value stored through `Arg`.
+  Type *ValueTy = nullptr;
+  /// The smallest alignment of the stores through `Arg`; callers store the
+  /// returned value back with this alignment.
+  Align Alignment;
+};
+
+/// Converts the eligible out arguments of a function into elements of a
+/// returned struct. The function is cloned into `<name>.converted`, where each
+/// out argument is backed by a local alloca whose final value is returned.
+/// Every direct, non-`musttail` `call` in a function the Attributor runs on is
+/// redirected to the clone and stores the returned values through the
+/// pointers it passed. The original function is kept for all other callers.
+struct AAConvertOutArgumentFunction final : AAConvertOutArgument {
+  AAConvertOutArgumentFunction(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    Function *F = getAssociatedFunction();
+    // Only definitions can be cloned. Variadic functions would need their
+    // extra call operands forwarded, and a `musttail` call must have its
+    // result returned unchanged, which the struct return would break. The
+    // new function is only created when running as a module pass.
+    if (!F || F->isDeclaration() || F->isVarArg() || !A.isModulePass() ||
+        any_of(*F, [](BasicBlock &BB) {
+          return BB.getTerminatingMustTailCall() != nullptr;
+        }))
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    Function &F = *getAssociatedFunction();
+    if (collectOutArgs(F).empty())
+      return indicatePessimisticFixpoint();
+
+    // If the function unwinds, the caller never reaches the store that writes
+    // an out value back although the original function already wrote it, so
+    // the function has to be (assumed) `nounwind`.
+    bool IsKnownNoUnwind = false;
+    if (!AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+            A, this, IRPosition::function(F), DepClassTy::REQUIRED,
+            IsKnownNoUnwind))
+      return indicatePessimisticFixpoint();
+
+    return IsKnownNoUnwind ? indicateOptimisticFixpoint()
+                           : ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    Function &F = *getAssociatedFunction();
+    SmallVector<OutArgInfo, 4> OutArgs = collectOutArgs(F);
+    SmallVector<CallInst *, 8> CallSites = collectRewritableCallSites(A, F);
+    // Without a call site to redirect, the clone would be dead code.
+    if (OutArgs.empty() || CallSites.empty())
+      return ChangeStatus::UNCHANGED;
+
+    Function &NewF = createConvertedFunction(F, OutArgs);
+    for (CallInst *CI : CallSites)
+      rewriteCallSite(A, *CI, F, NewF, OutArgs);
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractAttribute::getAsStr(...).
+  const std::string getAsStr(Attributor *A) const override {
+    return "AAConvertOutArgumentFunction";
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+
+private:
+  /// Return true and fill \p Info if \p Arg can be converted. \p Arg must be
+  /// a `noalias` pointer in the alloca address space whose only uses are
+  /// simple stores *through* it, all of one non-scalable type and all in the
+  /// entry block. Entry-block stores run before any `ret` is reached, so the
+  /// value returned for \p Arg is always initialized. As every use is such a
+  /// store, \p Arg is also write-only and not captured.
+  static bool analyzeOutArg(Argument &Arg, OutArgInfo &Info) {
+    if (!isEligibleArgumentType(Arg.getType()) || !Arg.hasNoAliasAttr() ||
+        Arg.hasPointeeInMemoryValueAttr() || Arg.use_empty())
+      return false;
+
+    Function &F = *Arg.getParent();
+    if (Arg.getType()->getPointerAddressSpace() !=
+        F.getParent()->getDataLayout().getAllocaAddrSpace())
+      return false;
+
+    Type *ValueTy = nullptr;
+    MaybeAlign StoreAlign;
+    for (Use &U : Arg.uses()) {
+      auto *SI = dyn_cast<StoreInst>(U.getUser());
+      // Reject all other users, and stores that write the pointer itself.
+      if (!SI || !SI->isSimple() ||
+          U.getOperandNo() != StoreInst::getPointerOperandIndex() ||
+          SI->getParent() != &F.getEntryBlock())
+        return false;
+      Type *StoredTy = SI->getValueOperand()->getType();
+      if (StoredTy->isScalableTy() || (ValueTy && ValueTy != StoredTy))
+        return false;
+      ValueTy = StoredTy;
+      StoreAlign =
+          StoreAlign ? std::min(*StoreAlign, SI->getAlign()) : SI->getAlign();
+    }
+    Info = {&Arg, ValueTy, *StoreAlign};
+    return true;
+  }
+
+  /// Collect the convertible out arguments of \p F in argument order.
+  static SmallVector<OutArgInfo, 4> collectOutArgs(Function &F) {
+    SmallVector<OutArgInfo, 4> OutArgs;
+    for (Argument &Arg : F.args()) {
+      OutArgInfo Info;
+      if (analyzeOutArg(Arg, Info))
+        OutArgs.push_back(Info);
+    }
+    return OutArgs;
+  }
+
+  /// Return true if \p Arg is one of the out arguments in \p OutArgs.
+  static bool isOutArg(ArrayRef<OutArgInfo> OutArgs, const Argument &Arg) {
+    return any_of(OutArgs,
+                  [&](const OutArgInfo &Info) { return Info.Arg == &Arg; });
+  }
+
+  /// Collect the direct, non-`musttail` calls of \p F that live in functions
+  /// the Attributor runs on; only those may be rewritten and deleted here.
+  static SmallVector<CallInst *, 8>
+  collectRewritableCallSites(Attributor &A, Function &F) {
+    SmallVector<CallInst *, 8> CallSites;
+    for (Use &U : F.uses()) {
+      auto *CI = dyn_cast<CallInst>(U.getUser());
+      if (CI && CI->isCallee(&U) && !CI->isMustTailCall() &&
+          CI->getFunctionType() == F.getFunctionType() &&
+          A.isRunOn(*CI->getFunction()))
+        CallSites.push_back(CI);
+    }
+    return CallSites;
+  }
+
+  /// Clone \p F into a new internal function without the out arguments that
+  /// returns `{original return value (if any), out values...}`.
+  static Function &createConvertedFunction(Function &F,
+                                           ArrayRef<OutArgInfo> OutArgs) {
+    LLVMContext &Ctx = F.getContext();
+    Type *OrigRetTy = F.getReturnType();
+
+    SmallVector<Type *, 4> ElementTys;
+    if (!OrigRetTy->isVoidTy())
+      ElementTys.push_back(OrigRetTy);
+    for (const OutArgInfo &Info : OutArgs)
+      ElementTys.push_back(Info.ValueTy);
+    StructType *RetStructTy =
+        StructType::create(Ctx, ElementTys, (F.getName() + "Out").str());
+
+    SmallVector<Type *, 4> ParamTys;
+    for (Argument &Arg : F.args())
+      if (!isOutArg(OutArgs, Arg))
+        ParamTys.push_back(Arg.getType());
+
+    // Start out with F's linkage: CloneFunctionInto copies F's visibility,
+    // which a local linkage may not have. It is made internal further below.
+    Function *NewF = Function::Create(
+        FunctionType::get(RetStructTy, ParamTys, F.isVarArg()),
+        F.getLinkage(), F.getAddressSpace(), F.getName() + ".converted",
+        F.getParent());
+
+    // CloneFunctionInto needs a mapping for every argument: out arguments
+    // map to allocas in a block placed before the cloned body.
+    BasicBlock *AllocaBB = BasicBlock::Create(Ctx, "", NewF);
+    IRBuilder<> AllocaBuilder(AllocaBB);
+    ValueToValueMapTy VMap;
+    SmallVector<AllocaInst *, 4> OutAllocas;
+    for (const OutArgInfo &Info : OutArgs) {
+      AllocaInst *Alloca = AllocaBuilder.CreateAlloca(
+          Info.ValueTy, nullptr, Info.Arg->getName() + ".out");
+      OutAllocas.push_back(Alloca);
+      VMap[Info.Arg] = Alloca;
+    }
+    Function::arg_iterator NewArgIt = NewF->arg_begin();
+    for (Argument &Arg : F.args()) {
+      if (isOutArg(OutArgs, Arg))
+        continue;
+      NewArgIt->setName(Arg.getName());
+      VMap[&Arg] = &*NewArgIt++;
+    }
+
+    SmallVector<ReturnInst *, 8> Returns;
+    CloneFunctionInto(NewF, &F, VMap, CloneFunctionChangeType::LocalChangesOnly,
+                      Returns);
+
+    // Return the original value (if any) followed by the final out values.
+    for (ReturnInst *Ret : Returns) {
+      IRBuilder<> Builder(Ret);
+      Value *Agg = PoisonValue::get(RetStructTy);
+      unsigned Idx = 0;
+      if (!OrigRetTy->isVoidTy())
+        Agg = Builder.CreateInsertValue(Agg, Ret->getReturnValue(), Idx++);
+      for (auto [Info, Alloca] : zip(OutArgs, OutAllocas)) {
+        Value *OutVal = Builder.CreateLoad(Info.ValueTy, Alloca,
+                                           Info.Arg->getName() + ".val");
+        Agg = Builder.CreateInsertValue(Agg, OutVal, Idx++);
+      }
+      Builder.CreateRet(Agg);
+      Ret->eraseFromParent();
+    }
+
+    // Fold the cloned entry block into the alloca block so that the new
+    // allocas and those of the original entry block stay in the entry block.
+    auto *ClonedEntry = cast<BasicBlock>(VMap[&F.getEntryBlock()]);
+    AllocaBuilder.CreateBr(ClonedEntry);
+    MergeBlockIntoPredecessor(ClonedEntry);
+
+    // Old return attributes (e.g. `zeroext`) may be invalid on a struct,
+    // `returned` no longer type-checks, and `allocsize` would refer to
+    // shifted parameter indices.
+    NewF->setAttributes(NewF->getAttributes().removeRetAttributes(Ctx));
+    NewF->removeFnAttr(Attribute::AllocSize);
+    for (Argument &NewArg : NewF->args())
+      NewArg.removeAttr(Attribute::Returned);
+    // Only the call sites rewritten by this attribute reference the clone.
+    NewF->setLinkage(GlobalValue::InternalLinkage);
+    return *NewF;
+  }
+
+  /// Replace \p CI, a call of \p F, with a call of \p NewF and store the
+  /// returned out values through the pointers \p CI passed for them.
+  static void rewriteCallSite(Attributor &A, CallInst &CI, Function &F,
+                              Function &NewF, ArrayRef<OutArgInfo> OutArgs) {
+    SmallVector<Value *, 4> NewArgs;
+    for (Argument &Arg : F.args())
+      if (!isOutArg(OutArgs, Arg))
+        NewArgs.push_back(CI.getArgOperand(Arg.getArgNo()));
+
+    SmallVector<OperandBundleDef, 1> Bundles;
+    CI.getOperandBundlesAsDefs(Bundles);
+
+    IRBuilder<> Builder(&CI);
+    CallInst *NewCI = Builder.CreateCall(&NewF, NewArgs, Bundles,
+                                         CI.getName() + ".converted");
+    NewCI->setCallingConv(CI.getCallingConv());
+
+    unsigned Idx = 0;
+    if (!F.getReturnType()->isVoidTy())
+      CI.replaceAllUsesWith(Builder.CreateExtractValue(NewCI, Idx++));
+    for (const OutArgInfo &Info : OutArgs)
+      Builder.CreateAlignedStore(Builder.CreateExtractValue(NewCI, Idx++),
+                                 CI.getArgOperand(Info.Arg->getArgNo()),
+                                 Info.Alignment);
+    // The old call is only erased during the Attributor's cleanup.
+    A.deleteAfterManifest(CI);
+  }
+};
+
+/// Call sites carry no state: AAConvertOutArgumentFunction::manifest rewrites
+/// the calls itself once the converted clone exists, so nothing depends on
+/// the order in which abstract attributes are manifested.
+struct AAConvertOutArgumentCallSite final : AAConvertOutArgument {
+  AAConvertOutArgumentCallSite(const IRPosition &IRP, Attributor &A)
+      : AAConvertOutArgument(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override { indicatePessimisticFixpoint(); }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    return indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractAttribute::getAsStr(...).
+  const std::string getAsStr(Attributor *A) const override {
+    return "AAConvertOutArgumentCallSite";
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+};
+} // namespace
+
 const char AANoUnwind::ID = 0;
 const char AANoSync::ID = 0;
 const char AANoFree::ID = 0;
@@ -13026,6 +13201,7 @@ const char AAAllocationInfo::ID = 0;
 const char AAIndirectCallInfo::ID = 0;
 const char AAGlobalValueInfo::ID = 0;
 const char AADenormalFPMath::ID = 0;
+const char AAConvertOutArgument::ID = 0;
 
 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
@@ -13141,6 +13317,7 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAConvertOutArgument)
 
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)



More information about the llvm-commits mailing list