[llvm] [WIP][AMDGPU][Attributor] Infer `inreg` attribute in `AMDGPUAttributor` (PR #101609)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 09:11:49 PDT 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/101609

>From f4b9eff66b39528fca6d255ce1f25b6cc1549223 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Thu, 1 Aug 2024 23:39:05 -0400
Subject: [PATCH] [WIP][AMDGPU][Attributor] Infer `inreg` attribute in
 `AMDGPUAttributor`

---
 llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp | 107 ++++++++++++++++++++
 1 file changed, 107 insertions(+)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
index d65e0ae92308e6..bcb36c2adbc57d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
@@ -1009,6 +1009,113 @@ struct AAAMDGPUNoAGPR
 
 const char AAAMDGPUNoAGPR::ID = 0;
 
+/// Abstract attribute deciding whether an argument position, or the operand
+/// passed for it at a call site, can carry the `inreg` attribute. The state is
+/// a single boolean (BooleanState): true while the position is still assumed
+/// to be `inreg`.
+struct AAAMDGPUInreg
+    : public IRAttribute<Attribute::InReg,
+                         StateWrapper<BooleanState, AbstractAttribute>,
+                         AAAMDGPUInreg> {
+  AAAMDGPUInreg(const IRPosition &IRP, Attributor &A) : IRAttribute(IRP) {}
+
+  /// Factory invoked by the Attributor; returns the position-specific
+  /// implementation for \p IRP.
+  static AAAMDGPUInreg &createForPosition(const IRPosition &IRP, Attributor &A);
+
+  /// Identity token: only its address matters, which makes it unique.
+  static const char ID;
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAAMDGPUInreg"; }
+
+  /// LLVM-style RTTI hook: true iff \p AA is an AAAMDGPUInreg.
+  static bool classof(const AbstractAttribute *AA) {
+    return &ID == AA->getIdAddr();
+  }
+};
+
+const char AAAMDGPUInreg::ID = 0;
+
+namespace {
+
+// NOTE(review): nothing in this hunk creates AAAMDGPUInreg instances; confirm
+// that AAAMDGPUInreg::ID is added to the Allowed set and that an AA is seeded
+// per argument in runImpl, otherwise this code is never exercised.
+
+/// Call-site-argument position of AAAMDGPUInreg: the value passed for one
+/// argument at one call site.
+///
+/// Conservative rule:
+///  - constants are accepted outright,
+///  - an argument of the caller is accepted only while that argument is itself
+///    (assumed) inreg,
+///  - every other value is rejected.
+///
+/// NOTE(review): the previous version never left its optimistic initial state,
+/// so every argument with known call sites would have been deduced inreg.
+/// Accepting instruction results would need a uniformity query
+/// (e.g. UniformityInfo) -- TODO if that precision is wanted.
+struct AAAMDGPUInregCallSiteArgument : public AAAMDGPUInreg {
+  AAAMDGPUInregCallSiteArgument(const IRPosition &IRP, Attributor &A)
+      : AAAMDGPUInreg(IRP, A) {}
+
+  void initialize(Attributor &A) override {
+    Value &V = getAssociatedValue();
+    // A constant has the same value for every thread executing the call, so
+    // it never prevents passing the operand inreg.
+    if (isa<Constant>(V)) {
+      indicateOptimisticFixpoint();
+      return;
+    }
+    // Only caller arguments are reasoned about (see updateImpl). Anything else,
+    // e.g. an instruction result, may be divergent.
+    if (!isa<Argument>(V))
+      indicatePessimisticFixpoint();
+  }
+
+  const std::string getAsStr(Attributor *A) const override {
+    return getAssumed() ? "inreg" : "non-inreg";
+  }
+
+  void trackStatistics() const override {}
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    auto *Arg = dyn_cast<Argument>(&getAssociatedValue());
+    if (!Arg)
+      return indicatePessimisticFixpoint();
+
+    // The passed value stays inreg only as long as the caller's argument does.
+    const auto *ArgAA = A.getOrCreateAAFor<AAAMDGPUInreg>(
+        IRPosition::argument(*Arg), this, DepClassTy::REQUIRED);
+    if (!ArgAA || !ArgAA->isAssumed())
+      return indicatePessimisticFixpoint();
+
+    return ChangeStatus::UNCHANGED;
+  }
+};
+
+/// Argument position of AAAMDGPUInreg: a formal argument stays assumed inreg
+/// only while the operand passed at *every* call site is (assumed) inreg.
+struct AAAMDGPUInregArgument : public AAAMDGPUInreg {
+  AAAMDGPUInregArgument(const IRPosition &IRP, Attributor &A)
+      : AAAMDGPUInreg(IRP, A) {}
+
+  void initialize(Attributor &A) override {
+    // An explicit inreg attribute is taken as final.
+    if (getAssociatedArgument()->hasAttribute(Attribute::InReg))
+      indicateOptimisticFixpoint();
+  }
+
+  const std::string getAsStr(Attributor *A) const override {
+    return getAssumed() ? "inreg" : "non-inreg";
+  }
+
+  void trackStatistics() const override {}
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    unsigned ArgNo = getAssociatedArgument()->getArgNo();
+
+    // True iff the operand for this argument at ACS is assumed inreg.
+    // A single failing call site ends the walk. The previous counting logic
+    // kept the optimistic state when only *some* call sites were inreg, which
+    // is unsound.
+    auto CallSitePassesInreg = [&](AbstractCallSite ACS) -> bool {
+      // Use the AbstractCallSite overload so the argument number is mapped to
+      // the right operand (matters for callback call sites). An invalid
+      // position means there is no matching operand, so give up.
+      const IRPosition CSArgPos = IRPosition::callsite_argument(ACS, ArgNo);
+      if (CSArgPos.getPositionKind() == IRPosition::IRP_INVALID)
+        return false;
+      const auto *AA = A.getOrCreateAAFor<AAAMDGPUInreg>(
+          CSArgPos, this, DepClassTy::REQUIRED);
+      return AA && AA->isAssumed();
+    };
+
+    // All call sites must be known and must pass an inreg operand.
+    bool UsedAssumedInformation = false;
+    if (!A.checkForAllCallSites(CallSitePassesInreg, *this,
+                                /*RequireAllCallSites=*/true,
+                                UsedAssumedInformation))
+      return indicatePessimisticFixpoint();
+
+    return ChangeStatus::UNCHANGED;
+  }
+};
+} // namespace
+
+/// Allocate, in the Attributor's arena, the AAAMDGPUInreg implementation that
+/// matches the kind of \p IRP. Only argument and call-site-argument positions
+/// are supported; any other kind is a programming error.
+AAAMDGPUInreg &AAAMDGPUInreg::createForPosition(const IRPosition &IRP,
+                                                Attributor &A) {
+  const IRPosition::Kind PosKind = IRP.getPositionKind();
+  if (PosKind == IRPosition::IRP_ARGUMENT)
+    return *new (A.Allocator) AAAMDGPUInregArgument(IRP, A);
+  if (PosKind == IRPosition::IRP_CALL_SITE_ARGUMENT)
+    return *new (A.Allocator) AAAMDGPUInregCallSiteArgument(IRP, A);
+  llvm_unreachable("not a valid position for AAAMDGPUInreg");
+}
+
 static void addPreloadKernArgHint(Function &F, TargetMachine &TM) {
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
   for (unsigned I = 0;



More information about the llvm-commits mailing list