[llvm] [IR] Check callee param attributes as well in CallBase::getParamAttr() (PR #91394)

Arthur Eubanks via llvm-commits llvm-commits at lists.llvm.org
Tue May 7 13:37:14 PDT 2024


https://github.com/aeubanks created https://github.com/llvm/llvm-project/pull/91394

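The CallBase::getParamAttr() overloads aren't used yet, but may be in the future. When the call site has no matching parameter attribute, they now fall back to the called function's parameter attributes, which keeps them in line with methods like getFnAttr() that already consult the callee.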

>From 971626754d49a1da3f515c4b89a77967f3685836 Mon Sep 17 00:00:00 2001
From: Arthur Eubanks <aeubanks at google.com>
Date: Tue, 7 May 2024 20:35:39 +0000
Subject: [PATCH] [IR] Check callee param attributes as well in
 CallBase::getParamAttr()

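The CallBase::getParamAttr() overloads aren't used yet, but may be in the future. When the call site has no matching parameter attribute, they now fall back to the called function's parameter attributes, which keeps them in line with methods like getFnAttr() that already consult the callee.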
---
 llvm/include/llvm/IR/InstrTypes.h    | 12 ++++++--
 llvm/lib/IR/Instructions.cpp         | 16 ++++++++++
 llvm/unittests/IR/AttributesTest.cpp | 46 ++++++++++++++++++++++++++++
 3 files changed, 72 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index b9af3a6ca42c0..9e492f9d4948e 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1997,13 +1997,19 @@ class CallBase : public Instruction {
   /// Get the attribute of a given kind from a given arg
   Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const {
     assert(ArgNo < arg_size() && "Out of bounds");
-    return getAttributes().getParamAttr(ArgNo, Kind);
+    // Call-site attributes take precedence; otherwise consult the callee.
+    if (Attribute A = getAttributes().getParamAttr(ArgNo, Kind); A.isValid())
+      return A;
+    return getParamAttrOnCalledFunction(ArgNo, Kind);
   }
 
   /// Get the attribute of a given kind from a given arg
   Attribute getParamAttr(unsigned ArgNo, StringRef Kind) const {
     assert(ArgNo < arg_size() && "Out of bounds");
-    return getAttributes().getParamAttr(ArgNo, Kind);
+    // Call-site attributes take precedence; otherwise consult the callee.
+    if (Attribute A = getAttributes().getParamAttr(ArgNo, Kind); A.isValid())
+      return A;
+    return getParamAttrOnCalledFunction(ArgNo, Kind);
   }
 
   /// Return true if the data operand at index \p i has the attribute \p
@@ -2647,6 +2653,8 @@ class CallBase : public Instruction {
     return hasFnAttrOnCalledFunction(Kind);
   }
   template <typename AK> Attribute getFnAttrOnCalledFunction(AK Kind) const;
+  template <typename AK>
+  Attribute getParamAttrOnCalledFunction(unsigned ArgNo, AK Kind) const;
 
   /// Determine whether the return value has the given attribute. Supports
   /// Attribute::AttrKind and StringRef as \p AttrKind types.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 7ad1ad4cddb70..32af58a43b68e 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -500,6 +500,22 @@ template Attribute
 CallBase::getFnAttrOnCalledFunction(Attribute::AttrKind Kind) const;
 template Attribute CallBase::getFnAttrOnCalledFunction(StringRef Kind) const;
 
+template <typename AK>
+Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+                                                 AK Kind) const {
+  // getCalledFunction() skips callees whose type differs from the call's, so
+  // positional param attributes are never read from a mismatched signature.
+  if (const Function *F = getCalledFunction())
+    return F->getAttributes().getParamAttr(ArgNo, Kind);
+
+  return Attribute();
+}
+template Attribute
+CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+                                       Attribute::AttrKind Kind) const;
+template Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+                                                          StringRef Kind) const;
+
 void CallBase::getOperandBundlesAsDefs(
     SmallVectorImpl<OperandBundleDef> &Defs) const {
   for (unsigned i = 0, e = getNumOperandBundles(); i != e; ++i)
diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp
index a7967593c2f96..da72fa14510cb 100644
--- a/llvm/unittests/IR/AttributesTest.cpp
+++ b/llvm/unittests/IR/AttributesTest.cpp
@@ -340,4 +340,50 @@ TEST(Attributes, ConstantRangeAttributeCAPI) {
   }
 }
 
+TEST(Attributes, CalleeAttributes) {
+  const char *IRString = R"IR(
+    declare void @f1(i32 %i)
+    declare void @f2(i32 range(i32 1, 2) %i)
+
+    define void @g1(i32 %i) {
+      call void @f1(i32 %i)
+      ret void
+    }
+    define void @g2(i32 %i) {
+      call void @f2(i32 %i)
+      ret void
+    }
+    define void @g3(i32 %i) {
+      call void @f1(i32 range(i32 3, 4) %i)
+      ret void
+    }
+    define void @g4(i32 %i) {
+      call void @f2(i32 range(i32 3, 4) %i)
+      ret void
+    }
+  )IR";
+
+  SMDiagnostic Err;
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
+  ASSERT_TRUE(M);
+
+  auto GetCallRange = [&](StringRef Fn) {
+    auto *I = cast<CallBase>(&M->getFunction(Fn)->getEntryBlock().front());
+    return I->getParamAttr(0, Attribute::Range);
+  };
+  auto MakeRange = [&](uint64_t Lo, uint64_t Hi) {
+    return Attribute::get(Context, Attribute::Range,
+                          ConstantRange(APInt(32, Lo), APInt(32, Hi)));
+  };
+
+  // g1: neither the call site nor the callee has a range attribute.
+  EXPECT_FALSE(GetCallRange("g1").isValid());
+  // g2/g3: the attribute comes from whichever side carries it.
+  EXPECT_EQ(GetCallRange("g2"), MakeRange(1, 2));
+  EXPECT_EQ(GetCallRange("g3"), MakeRange(3, 4));
+  // g4: both carry one; the call-site attribute must take precedence.
+  EXPECT_EQ(GetCallRange("g4"), MakeRange(3, 4));
+}
+
 } // end anonymous namespace



More information about the llvm-commits mailing list