[llvm] 00a8314 - [AArch64][SME] Extend Inliner cost-model with custom penalty for calls. (#68416)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 31 03:28:44 PDT 2023
Author: Sander de Smalen
Date: 2023-10-31T10:28:40Z
New Revision: 00a831421fdd94aec65221bdb37042c1aacfe8e0
URL: https://github.com/llvm/llvm-project/commit/00a831421fdd94aec65221bdb37042c1aacfe8e0
DIFF: https://github.com/llvm/llvm-project/commit/00a831421fdd94aec65221bdb37042c1aacfe8e0.diff
LOG: [AArch64][SME] Extend Inliner cost-model with custom penalty for calls. (#68416)
This is a stacked PR following on from #68415
This patch has two purposes:
(1) It tries to make inlining more likely when it can avoid a
streaming-mode change.
(2) It avoids inlining when inlining causes more streaming-mode changes.
An example of (1) is:
```
void streaming_compatible_bar(void);
void foo(void) __arm_streaming {
/* other code */
streaming_compatible_bar();
/* other code */
}
void f(void) {
foo(); // expensive streaming mode change
}
->
void f(void) {
/* other code */
streaming_compatible_bar();
/* other code */
}
```
where foo would not have been inlined into f if foo were a
non-streaming function.
An example of (2) is:
```
void streaming_bar(void) __arm_streaming;
void foo(void) __arm_streaming {
streaming_bar();
streaming_bar();
}
void f(void) {
foo(); // expensive streaming mode change
}
-> (do not inline into)
void f(void) {
streaming_bar(); // these are now two expensive streaming mode changes
streaming_bar();
}
```
Added:
llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
Modified:
llvm/include/llvm/Analysis/InlineCost.h
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/lib/Analysis/InlineCost.cpp
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/lib/Transforms/IPO/PartialInlining.cpp
llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index 57f452853d2d6d6..3f0bb879e021fd7 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -259,7 +259,8 @@ InlineParams getInlineParams(unsigned OptLevel, unsigned SizeOptLevel);
/// Return the cost associated with a callsite, including parameter passing
/// and the call/return instruction.
-int getCallsiteCost(const CallBase &Call, const DataLayout &DL);
+int getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+ const DataLayout &DL);
/// Get an InlineCost object representing the cost of inlining this
/// callsite.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3ec80d99b392b2e..c18e0acdb759a8d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1517,6 +1517,15 @@ class TargetTransformInfo {
bool areInlineCompatible(const Function *Caller,
const Function *Callee) const;
+ /// Returns a penalty for invoking call \p Call in \p F.
+ /// For example, if a function F calls a function G, which in turn calls
+ /// function H, then getInlineCallPenalty(F, H()) would return the
+ /// penalty of calling H from F, e.g. after inlining G into F.
+ /// \p DefaultCallPenalty is passed to give a default penalty that
+ /// the target can amend or override.
+ unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const;
+
/// \returns True if the caller and callee agree on how \p Types will be
/// passed to or returned from the callee.
/// to the callee.
@@ -2012,6 +2021,8 @@ class TargetTransformInfo::Concept {
std::optional<uint32_t> AtomicCpySize) const = 0;
virtual bool areInlineCompatible(const Function *Caller,
const Function *Callee) const = 0;
+ virtual unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const = 0;
virtual bool areTypesABICompatible(const Function *Caller,
const Function *Callee,
const ArrayRef<Type *> &Types) const = 0;
@@ -2673,6 +2684,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
const Function *Callee) const override {
return Impl.areInlineCompatible(Caller, Callee);
}
+ unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const override {
+ return Impl.getInlineCallPenalty(F, Call, DefaultCallPenalty);
+ }
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const override {
return Impl.areTypesABICompatible(Caller, Callee, Types);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index e14915443513990..2ccf57c22234f9a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -802,6 +802,11 @@ class TargetTransformInfoImplBase {
Callee->getFnAttribute("target-features"));
}
+ unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const {
+ return DefaultCallPenalty;
+ }
+
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const {
return (Caller->getFnAttribute("target-cpu") ==
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index fa0c30637633df3..7096e06d925adef 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -695,7 +695,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
}
} else
// Otherwise simply add the cost for merely making the call.
- addCost(CallPenalty);
+ addCost(TTI.getInlineCallPenalty(CandidateCall.getCaller(), Call,
+ CallPenalty));
}
void onFinalizeSwitch(unsigned JumpTableSize,
@@ -918,7 +919,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// Compute the total savings for the call site.
auto *CallerBB = CandidateCall.getParent();
BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent())));
- CycleSavings += getCallsiteCost(this->CandidateCall, DL);
+ CycleSavings += getCallsiteCost(TTI, this->CandidateCall, DL);
CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB);
// Remove the cost of the cold basic blocks to model the runtime cost more
@@ -1076,7 +1077,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// Give out bonuses for the callsite, as the instructions setting them up
// will be gone after inlining.
- addCost(-getCallsiteCost(this->CandidateCall, DL));
+ addCost(-getCallsiteCost(TTI, this->CandidateCall, DL));
// If this function uses the coldcc calling convention, prefer not to inline
// it.
@@ -1315,7 +1316,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
InlineResult onAnalysisStart() override {
increment(InlineCostFeatureIndex::callsite_cost,
- -1 * getCallsiteCost(this->CandidateCall, DL));
+ -1 * getCallsiteCost(TTI, this->CandidateCall, DL));
set(InlineCostFeatureIndex::cold_cc_penalty,
(F.getCallingConv() == CallingConv::Cold));
@@ -2887,7 +2888,8 @@ static bool functionsHaveCompatibleAttributes(
AttributeFuncs::areInlineCompatible(*Caller, *Callee);
}
-int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
+int llvm::getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+ const DataLayout &DL) {
int64_t Cost = 0;
for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
if (Call.isByValArgument(I)) {
@@ -2917,7 +2919,8 @@ int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
}
// The call instruction also disappears after inlining.
Cost += InstrCost;
- Cost += CallPenalty;
+ Cost += TTI.getInlineCallPenalty(Call.getCaller(), Call, CallPenalty);
+
return std::min<int64_t>(Cost, INT_MAX);
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..caa9b17ae695e49 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1133,6 +1133,13 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
return TTIImpl->areInlineCompatible(Caller, Callee);
}
+unsigned
+TargetTransformInfo::getInlineCallPenalty(const Function *F,
+ const CallBase &Call,
+ unsigned DefaultCallPenalty) const {
+ return TTIImpl->getInlineCallPenalty(F, Call, DefaultCallPenalty);
+}
+
bool TargetTransformInfo::areTypesABICompatible(
const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 776619c90393c03..0eaa3e817c0b62d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -46,6 +46,15 @@ static cl::opt<unsigned>
NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
cl::Hidden);
+static cl::opt<unsigned> CallPenaltyChangeSM(
+ "call-penalty-sm-change", cl::init(5), cl::Hidden,
+ cl::desc(
+ "Penalty of calling a function that requires a change to PSTATE.SM"));
+
+static cl::opt<unsigned> InlineCallPenaltyChangeSM(
+ "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
+ cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
+
namespace {
class TailFoldingOption {
// These bitfields will only ever be set to something non-zero in operator=,
@@ -269,6 +278,40 @@ bool AArch64TTIImpl::areTypesABICompatible(
return true;
}
+unsigned
+AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const {
+ // This function calculates a penalty for executing Call in F.
+ //
+ // There are two ways this function can be called:
+ // (1) F:
+ // call from F -> G (the call here is Call)
+ //
+ // For (1), Call.getCaller() == F, so it will always return a high cost if
+ // a streaming-mode change is required (thus promoting the need to inline the
+ // function)
+ //
+ // (2) F:
+ // call from F -> G (the call here is not Call)
+ // G:
+ // call from G -> H (the call here is Call)
+ //
+ // For (2), if after inlining the body of G into F the call to H requires a
+ // streaming-mode change, and the call to G from F would also require a
+ // streaming-mode change, then there is benefit to do the streaming-mode
+ // change only once and avoid inlining of G into F.
+ SMEAttrs FAttrs(*F);
+ SMEAttrs CalleeAttrs(Call);
+ if (FAttrs.requiresSMChange(CalleeAttrs)) {
+ if (F == Call.getCaller()) // (1)
+ return CallPenaltyChangeSM * DefaultCallPenalty;
+ if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
+ return InlineCallPenaltyChangeSM * DefaultCallPenalty;
+ }
+
+ return DefaultCallPenalty;
+}
+
bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
TargetTransformInfo::RegisterKind K) const {
assert(K != TargetTransformInfo::RGK_Scalar);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index c08004ad299fd68..fa4c93d5f77a196 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -80,6 +80,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const;
+ unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+ unsigned DefaultCallPenalty) const;
+
/// \name Scalar TTI Implementations
/// @{
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 25da06add24f031..aa4f205ec5bdf1e 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
const DataLayout &DL = Caller->getParent()->getDataLayout();
// The savings of eliminating the call:
- int NonWeightedSavings = getCallsiteCost(CB, DL);
+ int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
BlockFrequency NormWeightedSavings(NonWeightedSavings);
// Weighted saving is smaller than weighted cost, return false
@@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
}
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
- InlineCost += getCallsiteCost(*CI, DL);
+ InlineCost += getCallsiteCost(*TTI, *CI, DL);
continue;
}
if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
- InlineCost += getCallsiteCost(*II, DL);
+ InlineCost += getCallsiteCost(*TTI, *II, DL);
continue;
}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
new file mode 100644
index 000000000000000..72003d2fee4bac6
--- /dev/null
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -S -passes=inline -inlinedefault-threshold=1 | FileCheck %s
+
+; This test sets the inline-threshold to 1 such that by default the call to @streaming_callee is not inlined.
+; However, if the call to @streaming_callee requires a streaming-mode change, it should always inline the call because the streaming-mode change is more expensive.
+target triple = "aarch64"
+
+declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"
+
+; Function @streaming_callee doesn't contain any operations that may use ZA
+; state and therefore can be legally inlined into a normal function.
+define void @streaming_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_callee
+; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT: call void @streaming_compatible_f()
+; CHECK-NEXT: call void @streaming_compatible_f()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_compatible_f()
+ call void @streaming_compatible_f()
+ ret void
+}
+
+; Inline call to @streaming_callee to remove a streaming mode change.
+define void @non_streaming_caller_inline() {
+; CHECK-LABEL: define void @non_streaming_caller_inline
+; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT: call void @streaming_compatible_f()
+; CHECK-NEXT: call void @streaming_compatible_f()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_callee()
+ ret void
+}
+
+; Don't inline call to @streaming_callee when the inline-threshold is set to 1, because it does not eliminate a streaming-mode change.
+define void @streaming_caller_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT: call void @streaming_callee()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_callee()
+ ret void
+}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
index f2f5768dbe9c6e9..d6b1f3ef45e7655 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
@@ -581,3 +581,98 @@ entry:
%res = call i64 @normal_callee_call_sme_state()
ret i64 %res
}
+
+
+
+declare void @streaming_body() "aarch64_pstate_sm_enabled"
+
+define void @streaming_caller_single_streaming_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_callee
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT: call void @streaming_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_body()
+ ret void
+}
+
+define void @streaming_caller_multiple_streaming_callees() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT: call void @streaming_body()
+; CHECK-NEXT: call void @streaming_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_body()
+ call void @streaming_body()
+ ret void
+}
+
+; Allow inlining, as inlining it would not increase the number of streaming-mode changes.
+define void @streaming_caller_single_streaming_callee_inline() {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_callee_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT: call void @streaming_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_caller_single_streaming_callee()
+ ret void
+}
+
+; Prevent inlining, as inlining it would lead to multiple streaming-mode changes.
+define void @streaming_caller_multiple_streaming_callees_dont_inline() {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT: call void @streaming_caller_multiple_streaming_callees()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_caller_multiple_streaming_callees()
+ ret void
+}
+
+declare void @streaming_compatible_body() "aarch64_pstate_sm_compatible"
+
+define void @streaming_caller_single_streaming_compatible_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_compatible_body()
+ ret void
+}
+
+define void @streaming_caller_multiple_streaming_compatible_callees() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_compatible_body()
+ call void @streaming_compatible_body()
+ ret void
+}
+
+; Allow inlining, as inlining would remove a streaming-mode change.
+define void @streaming_caller_single_streaming_compatible_callee_inline() {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_caller_single_streaming_compatible_callee()
+ ret void
+}
+
+; Allow inlining, as inlining would remove several streaming-mode changes.
+define void @streaming_caller_multiple_streaming_compatible_callees_inline() {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: call void @streaming_compatible_body()
+; CHECK-NEXT: ret void
+;
+ call void @streaming_caller_multiple_streaming_compatible_callees()
+ ret void
+}
More information about the llvm-commits
mailing list