[llvm] a620697 - [IR] Check callee param attributes as well in CallBase::getParamAttr() (#91394)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 8 10:14:55 PDT 2024
Author: Arthur Eubanks
Date: 2024-05-08T10:14:51-07:00
New Revision: a620697340671aea2b0c65449fcddf3c2e4d1917
URL: https://github.com/llvm/llvm-project/commit/a620697340671aea2b0c65449fcddf3c2e4d1917
DIFF: https://github.com/llvm/llvm-project/commit/a620697340671aea2b0c65449fcddf3c2e4d1917.diff
LOG: [IR] Check callee param attributes as well in CallBase::getParamAttr() (#91394)
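When the call site does not carry the requested parameter attribute,
CallBase::getParamAttr() now falls back to the parameter attributes of the
called function. These methods aren't used yet, but may be in the future.
This keeps them in line with other methods like getFnAttr().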
Added:
Modified:
llvm/include/llvm/IR/InstrTypes.h
llvm/lib/IR/Instructions.cpp
llvm/unittests/IR/AttributesTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index eaade9ce4755f..9dd1bb455a718 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1997,13 +1997,26 @@ class CallBase : public Instruction {
/// Get the attribute of a given kind from a given arg
+ ///
+ /// The call-site attribute is returned if present; otherwise the parameter
+ /// attribute of the called function is consulted via
+ /// getParamAttrOnCalledFunction().
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const {
assert(ArgNo < arg_size() && "Out of bounds");
- return getAttributes().getParamAttr(ArgNo, Kind);
+ Attribute A = getAttributes().getParamAttr(ArgNo, Kind);
+ if (A.isValid())
+ return A;
+ return getParamAttrOnCalledFunction(ArgNo, Kind);
}
/// Get the attribute of a given kind from a given arg
+ ///
+ /// Like the AttrKind overload, falls back to the called function's parameter
+ /// attribute when the call site does not carry one.
Attribute getParamAttr(unsigned ArgNo, StringRef Kind) const {
assert(ArgNo < arg_size() && "Out of bounds");
- return getAttributes().getParamAttr(ArgNo, Kind);
+ Attribute A = getAttributes().getParamAttr(ArgNo, Kind);
+ if (A.isValid())
+ return A;
+ return getParamAttrOnCalledFunction(ArgNo, Kind);
}
/// Return true if the data operand at index \p i has the attribute \p
@@ -2652,6 +2665,10 @@ class CallBase : public Instruction {
return hasFnAttrOnCalledFunction(Kind);
}
template <typename AK> Attribute getFnAttrOnCalledFunction(AK Kind) const;
+ /// Returns the attribute of kind \p Kind on parameter \p ArgNo of the called
+ /// function, or an invalid Attribute if the callee is unknown or lacks it.
+ template <typename AK>
+ Attribute getParamAttrOnCalledFunction(unsigned ArgNo, AK Kind) const;
/// Determine whether the return value has the given attribute. Supports
/// Attribute::AttrKind and StringRef as \p AttrKind types.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 7ad1ad4cddb70..32af58a43b68e 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -500,6 +500,28 @@ template Attribute
CallBase::getFnAttrOnCalledFunction(Attribute::AttrKind Kind) const;
template Attribute CallBase::getFnAttrOnCalledFunction(StringRef Kind) const;
+// Look up a parameter attribute on the directly called function.
+//
+// Returns an invalid Attribute when the called operand is not a Function
+// (e.g. an indirect call) or when the callee's function type differs from the
+// call's function type. In the mismatched case the argument at ArgNo may have
+// a different type than the callee's parameter, so attributes such as range
+// or align declared on the callee do not describe the value actually passed.
+template <typename AK>
+Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+ AK Kind) const {
+ auto *F = dyn_cast<Function>(getCalledOperand());
+ if (!F || F->getFunctionType() != getFunctionType())
+ return Attribute();
+
+ return F->getAttributes().getParamAttr(ArgNo, Kind);
+}
+template Attribute
+CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+ Attribute::AttrKind Kind) const;
+template Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
+ StringRef Kind) const;
+
void CallBase::getOperandBundlesAsDefs(
SmallVectorImpl<OperandBundleDef> &Defs) const {
for (unsigned i = 0, e = getNumOperandBundles(); i != e; ++i)
diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp
index a7967593c2f96..da72fa14510cb 100644
--- a/llvm/unittests/IR/AttributesTest.cpp
+++ b/llvm/unittests/IR/AttributesTest.cpp
@@ -340,4 +340,66 @@ TEST(Attributes, ConstantRangeAttributeCAPI) {
}
}
+TEST(Attributes, CalleeAttributes) {
+ const char *IRString = R"IR(
+ declare void @f1(i32 %i)
+ declare void @f2(i32 range(i32 1, 2) %i)
+
+ define void @g1(i32 %i) {
+ call void @f1(i32 %i)
+ ret void
+ }
+ define void @g2(i32 %i) {
+ call void @f2(i32 %i)
+ ret void
+ }
+ define void @g3(i32 %i) {
+ call void @f1(i32 range(i32 3, 4) %i)
+ ret void
+ }
+ define void @g4(i32 %i) {
+ call void @f2(i32 range(i32 3, 4) %i)
+ ret void
+ }
+ define void @g5(i64 %i) {
+ call void @f2(i64 %i)
+ ret void
+ }
+ )IR";
+
+ SMDiagnostic Err;
+ LLVMContext Context;
+ std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
+ ASSERT_TRUE(M);
+
+ {
+ // Neither the call site nor the callee has the attribute.
+ auto *I = cast<CallBase>(&M->getFunction("g1")->getEntryBlock().front());
+ ASSERT_FALSE(I->getParamAttr(0, Attribute::Range).isValid());
+ }
+ {
+ // Only the callee has the attribute: it is found via the fallback.
+ auto *I = cast<CallBase>(&M->getFunction("g2")->getEntryBlock().front());
+ ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid());
+ }
+ {
+ // Only the call site has the attribute.
+ auto *I = cast<CallBase>(&M->getFunction("g3")->getEntryBlock().front());
+ ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid());
+ }
+ {
+ // Both have it: the call-site attribute takes precedence.
+ auto *I = cast<CallBase>(&M->getFunction("g4")->getEntryBlock().front());
+ Attribute A = I->getParamAttr(0, Attribute::Range);
+ ASSERT_TRUE(A.isValid());
+ EXPECT_TRUE(A.getRange().getLower() == 3);
+ }
+ {
+ // The call's signature (i64) does not match the callee's (i32): the
+ // callee's range attribute describes a different type and is ignored.
+ auto *I = cast<CallBase>(&M->getFunction("g5")->getEntryBlock().front());
+ ASSERT_FALSE(I->getParamAttr(0, Attribute::Range).isValid());
+ }
+}
+
} // end anonymous namespace
More information about the llvm-commits mailing list