[llvm] [AArch64][SME] Extend Inliner cost-model with custom penalty for calls. (PR #68416)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 6 06:07:12 PDT 2023


https://github.com/sdesmalen-arm created https://github.com/llvm/llvm-project/pull/68416

This is a stacked PR following on from #68415 

This patch has two purposes:
(1) It tries to make inlining more likely when it can avoid a streaming-mode change.
(2) It avoids inlining when inlining causes more streaming-mode changes.

An example of (1) is:
```
  void streaming_compatible_bar(void);

  void foo(void) __arm_streaming {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  ->

  void f(void) {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }
```
where it wouldn't have inlined the function if foo were a non-streaming function.

An example of (2) is:
```
  void streaming_bar(void) __arm_streaming;

  void foo(void) __arm_streaming {
    streaming_bar();
    streaming_bar();
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  -> (do not inline into)

  void f(void) {
    streaming_bar();  // these are now two expensive streaming mode changes
    streaming_bar();
  }
```

>From 2ec46c7d50dde0c0dddd39c3936c57310bb09d4e Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 7 Sep 2023 17:30:32 +0100
Subject: [PATCH 1/2] [AArch64][SME] Allow inlining when streaming-mode
 attributes don't match up.

The use-case here is to support things like:

  int foo(int x, int y) __arm_streaming { return std::max<int>(x, y); }

where the call to non-streaming `std::max<int>(x, y)` can be safely inlined
into the streaming function.

This is a first step and will need further work to allow more cases (e.g.
more finegrained analysis of the function calls to ensure they don't
result in any incompatible instructions for the requested mode).
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |  41 +++++-
 .../Inline/AArch64/sme-pstatesm-attrs.ll      | 138 ++++++++++++++----
 .../Inline/AArch64/sme-pstateza-attrs.ll      |  71 ++++++++-
 3 files changed, 214 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cded28054f59259..d053350c08bf9ab 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -190,16 +190,49 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
 
+static bool isSMEABIRoutineCall(const CallInst &CI) {
+  const auto *F = CI.getCalledFunction();
+  return F && StringSwitch<bool>(F->getName())
+                  .Case("__arm_sme_state", true)
+                  .Case("__arm_tpidr2_save", true)
+                  .Case("__arm_tpidr2_restore", true)
+                  .Case("__arm_za_disable", true)
+                  .Default(false);
+}
+
+/// Returns true if the function has explicit operations that can only be lowered
+/// using incompatible instructions for the selected mode.
+/// This also returns true if the function F may use or modify ZA state.
+static bool hasPossibleIncompatibleOps(const Function *F) {
+  for (const BasicBlock &BB : *F) {
+    for (const Instruction &I : BB) {
+      // Be conservative for now and assume that any call to inline asm or to
+      // intrinsics could result in non-streaming ops (e.g. calls to
+      // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
+      // all native LLVM instructions can be lowered to compatible instructions.
+      if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
+          (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
+           isSMEABIRoutineCall(cast<CallInst>(I))))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                          const Function *Callee) const {
   SMEAttrs CallerAttrs(*Caller);
   SMEAttrs CalleeAttrs(*Callee);
-  if (CallerAttrs.requiresSMChange(CalleeAttrs,
-                                   /*BodyOverridesInterface=*/true) ||
-      CallerAttrs.requiresLazySave(CalleeAttrs) ||
-      CalleeAttrs.hasNewZABody())
+  if (CalleeAttrs.hasNewZABody())
     return false;
 
+  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresSMChange(CalleeAttrs,
+                                   /*BodyOverridesInterface=*/true)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
   const TargetMachine &TM = getTLI()->getTargetMachine();
 
   const FeatureBitset &CallerBits =
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
index 3df5400875ae288..f2f5768dbe9c6e9 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
@@ -102,11 +102,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_callee_dont_inline
+define void @normal_caller_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -136,11 +136,11 @@ entry:
 ; [ ] N  -> SC
 ; [x] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_dont_inline
+define void @normal_caller_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -153,11 +153,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [x] N  -> SC + B
-define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline
+define void @normal_caller_streaming_compatible_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_compatible_locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -170,11 +170,11 @@ entry:
 ; [ ] S  -> SC
 ; [ ] S  -> N + B
 ; [ ] S  -> SC + B
-define void @streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_enabled" {
-; CHECK-LABEL: define void @streaming_caller_normal_callee_dont_inline
+define void @streaming_caller_normal_callee_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -255,11 +255,11 @@ entry:
 ; [ ] N + B -> SC
 ; [ ] N + B -> N + B
 ; [ ] N + B -> SC + B
-define void @locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_body" {
-; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_dont_inline
+define void @locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR3]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -340,11 +340,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_dont_inline
+define void @streaming_compatible_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -357,11 +357,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_dont_inline
+define void @streaming_compatible_caller_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -391,11 +391,11 @@ entry:
 ; [ ] SC -> SC
 ; [x] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_dont_inline
+define void @streaming_compatible_caller_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -408,11 +408,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [x] SC -> SC + B
-define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline
+define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_compatible_locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -424,11 +424,11 @@ entry:
 ; [ ] SC + B -> SC
 ; [ ] SC + B -> N + B
 ; [ ] SC + B -> SC + B
-define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
-; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline
+define void @streaming_compatible_locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR4]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -503,3 +503,81 @@ entry:
   call void @streaming_compatible_locally_streaming_callee()
   ret void
 }
+
+define void @normal_callee_with_inlineasm() {
+; CHECK-LABEL: define void @normal_callee_with_inlineasm
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void asm sideeffect "; inlineasm", ""()
+  ret void
+}
+
+define void @streaming_caller_normal_callee_with_inlineasm_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_normal_callee_with_inlineasm_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @normal_callee_with_inlineasm()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @normal_callee_with_inlineasm()
+  ret void
+}
+
+define i64 @normal_callee_with_intrinsic_call() {
+; CHECK-LABEL: define i64 @normal_callee_with_intrinsic_call
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.aarch64.sve.cntb(i32 4)
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @llvm.aarch64.sve.cntb(i32 4)
+  ret i64 %res
+}
+
+define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @normal_callee_with_intrinsic_call()
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @normal_callee_with_intrinsic_call()
+  ret i64 %res
+}
+
+declare i64 @llvm.aarch64.sve.cntb(i32)
+
+define i64 @normal_callee_call_sme_state() {
+; CHECK-LABEL: define i64 @normal_callee_call_sme_state
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call { i64, i64 } @__arm_sme_state()
+; CHECK-NEXT:    [[RES_0:%.*]] = extractvalue { i64, i64 } [[RES]], 0
+; CHECK-NEXT:    ret i64 [[RES_0]]
+;
+entry:
+  %res = call {i64, i64} @__arm_sme_state()
+  %res.0 = extractvalue {i64, i64} %res, 0
+  ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
+define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @normal_callee_call_sme_state()
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @normal_callee_call_sme_state()
+  ret i64 %res
+}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index a833e7a911ac03f..7b104977cff5a7b 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -3,10 +3,12 @@
 
 declare void @inlined_body()
 
+;
 ; Define some functions that will be called by the functions below.
 ; These just call a '...body()' function. If we see the call to one of
 ; these functions being replaced by '...body()', then we know it has been
 ; inlined.
+;
 
 define void @nonza_callee() {
 ; CHECK-LABEL: define void @nonza_callee
@@ -42,6 +44,7 @@ define void @new_za_callee() "aarch64_pstate_za_new" {
   ret void
 }
 
+;
 ; Now test that inlining only happens when no lazy-save is needed.
 ; Test for a number of combinations, where:
 ; N   Not using ZA.
@@ -85,7 +88,7 @@ define void @new_za_caller_nonza_callee_dont_inline() "aarch64_pstate_za_new" {
 ; CHECK-LABEL: define void @new_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @nonza_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -130,7 +133,7 @@ define void @shared_za_caller_nonza_callee_dont_inline() "aarch64_pstate_za_shar
 ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @nonza_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -167,3 +170,67 @@ entry:
   call void @shared_za_callee()
   ret void
 }
+
+define void @private_za_callee_call_za_disable() {
+; CHECK-LABEL: define void @private_za_callee_call_za_disable
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_za_disable()
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_za_disable()
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_za_disable() "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_za_disable
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_za_disable()
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_za_disable()
+  ret void
+}
+
+define void @private_za_callee_call_tpidr2_save() {
+; CHECK-LABEL: define void @private_za_callee_call_tpidr2_save
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_tpidr2_save()
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_tpidr2_save()
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline() "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_save()
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_tpidr2_save()
+  ret void
+}
+
+define void @private_za_callee_call_tpidr2_restore(ptr %ptr) {
+; CHECK-LABEL: define void @private_za_callee_call_tpidr2_restore
+; CHECK-SAME: (ptr [[PTR:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_tpidr2_restore(ptr [[PTR]])
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_tpidr2_restore(ptr %ptr)
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline(ptr %ptr) "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline
+; CHECK-SAME: (ptr [[PTR:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_restore(ptr [[PTR]])
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_tpidr2_restore(ptr %ptr)
+  ret void
+}
+
+declare void @__arm_za_disable()
+declare void @__arm_tpidr2_save()
+declare void @__arm_tpidr2_restore(ptr)

>From e98d96e643aa52828f54b0a1f3e3d113896185fa Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 11 Sep 2023 15:06:30 +0100
Subject: [PATCH 2/2] [AArch64][SME] Extend Inliner cost-model with custom
 penalty for calls.

This patch has two purposes:
(1) It tries to make inlining more likely when it can avoid a streaming-mode change.
(2) It avoids inlining when inlining causes more streaming-mode changes.

An example of (1) is:

  void streaming_compatible_bar(void);

  void foo(void) __arm_streaming {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  ->

  void f(void) {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }

  where it wouldn't have inlined the function if foo were a non-streaming function.

An example of (2) is:

  void streaming_bar(void) __arm_streaming;

  void foo(void) __arm_streaming {
    streaming_bar();
    streaming_bar();
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  -> (do not inline into)

  void f(void) {
    streaming_bar();  // these are now two expensive streaming mode changes
    streaming_bar();
  }
---
 llvm/include/llvm/Analysis/InlineCost.h       |  3 +-
 .../llvm/Analysis/TargetTransformInfo.h       | 16 ++++
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  5 +
 llvm/lib/Analysis/InlineCost.cpp              | 15 +--
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  6 ++
 .../AArch64/AArch64TargetTransformInfo.cpp    | 34 +++++++
 .../AArch64/AArch64TargetTransformInfo.h      |  3 +
 llvm/lib/Transforms/IPO/PartialInlining.cpp   |  6 +-
 .../sme-pstatesm-attrs-low-threshold.ll       | 43 +++++++++
 .../Inline/AArch64/sme-pstatesm-attrs.ll      | 95 +++++++++++++++++++
 10 files changed, 216 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll

diff --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index 57f452853d2d6d6..3f0bb879e021fd7 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -259,7 +259,8 @@ InlineParams getInlineParams(unsigned OptLevel, unsigned SizeOptLevel);
 
 /// Return the cost associated with a callsite, including parameter passing
 /// and the call/return instruction.
-int getCallsiteCost(const CallBase &Call, const DataLayout &DL);
+int getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+                    const DataLayout &DL);
 
 /// Get an InlineCost object representing the cost of inlining this
 /// callsite.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..7a85c03d659232e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1506,6 +1506,15 @@ class TargetTransformInfo {
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 
+  /// Returns a penalty for invoking call \p Call in \p F.
+  /// For example, if a function F calls a function G, which in turn calls
+  /// function H, then getInlineCallPenalty(F, H()) would return the
+  /// penalty of calling H from F, e.g. after inlining G into F.
+  /// \p DefaultCallPenalty is passed to give a default penalty that
+  /// the target can amend or override.
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const;
+
   /// \returns True if the caller and callee agree on how \p Types will be
   /// passed to or returned from the callee.
   /// to the callee.
@@ -2001,6 +2010,8 @@ class TargetTransformInfo::Concept {
       std::optional<uint32_t> AtomicCpySize) const = 0;
   virtual bool areInlineCompatible(const Function *Caller,
                                    const Function *Callee) const = 0;
+  virtual unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                        unsigned DefaultCallPenalty) const = 0;
   virtual bool areTypesABICompatible(const Function *Caller,
                                      const Function *Callee,
                                      const ArrayRef<Type *> &Types) const = 0;
@@ -2662,6 +2673,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                            const Function *Callee) const override {
     return Impl.areInlineCompatible(Caller, Callee);
   }
+  unsigned
+  getInlineCallPenalty(const Function *F, const CallBase &Call,
+                       unsigned DefaultCallPenalty) const override {
+    return Impl.getInlineCallPenalty(F, Call, DefaultCallPenalty);
+  }
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const override {
     return Impl.areTypesABICompatible(Caller, Callee, Types);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..e6fc178365626ba 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -802,6 +802,11 @@ class TargetTransformInfoImplBase {
             Callee->getFnAttribute("target-features"));
   }
 
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const {
+    return DefaultCallPenalty;
+  }
+
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const {
     return (Caller->getFnAttribute("target-cpu") ==
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index fa0c30637633df3..7096e06d925adef 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -695,7 +695,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       }
     } else
       // Otherwise simply add the cost for merely making the call.
-      addCost(CallPenalty);
+      addCost(TTI.getInlineCallPenalty(CandidateCall.getCaller(), Call,
+                                       CallPenalty));
   }
 
   void onFinalizeSwitch(unsigned JumpTableSize,
@@ -918,7 +919,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
     // Compute the total savings for the call site.
     auto *CallerBB = CandidateCall.getParent();
     BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent())));
-    CycleSavings += getCallsiteCost(this->CandidateCall, DL);
+    CycleSavings += getCallsiteCost(TTI, this->CandidateCall, DL);
     CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB);
 
     // Remove the cost of the cold basic blocks to model the runtime cost more
@@ -1076,7 +1077,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
     // Give out bonuses for the callsite, as the instructions setting them up
     // will be gone after inlining.
-    addCost(-getCallsiteCost(this->CandidateCall, DL));
+    addCost(-getCallsiteCost(TTI, this->CandidateCall, DL));
 
     // If this function uses the coldcc calling convention, prefer not to inline
     // it.
@@ -1315,7 +1316,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   InlineResult onAnalysisStart() override {
     increment(InlineCostFeatureIndex::callsite_cost,
-              -1 * getCallsiteCost(this->CandidateCall, DL));
+              -1 * getCallsiteCost(TTI, this->CandidateCall, DL));
 
     set(InlineCostFeatureIndex::cold_cc_penalty,
         (F.getCallingConv() == CallingConv::Cold));
@@ -2887,7 +2888,8 @@ static bool functionsHaveCompatibleAttributes(
          AttributeFuncs::areInlineCompatible(*Caller, *Callee);
 }
 
-int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
+int llvm::getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+                          const DataLayout &DL) {
   int64_t Cost = 0;
   for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
     if (Call.isByValArgument(I)) {
@@ -2917,7 +2919,8 @@ int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
   }
   // The call instruction also disappears after inlining.
   Cost += InstrCost;
-  Cost += CallPenalty;
+  Cost += TTI.getInlineCallPenalty(Call.getCaller(), Call, CallPenalty);
+
   return std::min<int64_t>(Cost, INT_MAX);
 }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..10ed3a4437dae5f 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1133,6 +1133,12 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
   return TTIImpl->areInlineCompatible(Caller, Callee);
 }
 
+unsigned TargetTransformInfo::getInlineCallPenalty(
+    const Function *F, const CallBase &Call,
+    unsigned DefaultCallPenalty) const {
+  return TTIImpl->getInlineCallPenalty(F, Call, DefaultCallPenalty);
+}
+
 bool TargetTransformInfo::areTypesABICompatible(
     const Function *Caller, const Function *Callee,
     const ArrayRef<Type *> &Types) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d053350c08bf9ab..e107c7d8540cce5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -245,6 +245,40 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
   return (CallerBits & CalleeBits) == CalleeBits;
 }
 
+unsigned
+AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                     unsigned DefaultCallPenalty) const {
+  // This function calculates a penalty for executing Call in F.
+  //
+  // There are two ways this function can be called:
+  // (1)  F:
+  //       call from F -> G (the call here is Call)
+  //
+  // For (1), Call.getCaller() == F, so it will always return a high cost if
+  // a streaming-mode change is required (thus promoting the need to inline the
+  // function)
+  //
+  // (2)  F:
+  //       call from F -> G (the call here is not Call)
+  //      G:
+  //       call from G -> H (the call here is Call)
+  //
+  // For (2), if after inlining the body of G into F the call to H requires a
+  // streaming-mode change, and the call to G from F would also require a
+  // streaming-mode change, then there is benefit to do the streaming-mode
+  // change only once and avoid inlining of G into F.
+  SMEAttrs FAttrs(*F);
+  SMEAttrs CalleeAttrs(Call);
+  if (FAttrs.requiresSMChange(CalleeAttrs)) {
+    if (F == Call.getCaller())                                // (1)
+      return 5 * DefaultCallPenalty;
+    if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
+      return 10 * DefaultCallPenalty;
+  }
+
+  return DefaultCallPenalty;
+}
+
 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
     TargetTransformInfo::RegisterKind K) const {
   assert(K != TargetTransformInfo::RGK_Scalar);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..cccce44fe35ffc0 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -77,6 +77,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const;
+
   /// \name Scalar TTI Implementations
   /// @{
 
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 25da06add24f031..aa4f205ec5bdf1e 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
   const DataLayout &DL = Caller->getParent()->getDataLayout();
 
   // The savings of eliminating the call:
-  int NonWeightedSavings = getCallsiteCost(CB, DL);
+  int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
   BlockFrequency NormWeightedSavings(NonWeightedSavings);
 
   // Weighted saving is smaller than weighted cost, return false
@@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
     }
 
     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
-      InlineCost += getCallsiteCost(*CI, DL);
+      InlineCost += getCallsiteCost(*TTI, *CI, DL);
       continue;
     }
 
     if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
-      InlineCost += getCallsiteCost(*II, DL);
+      InlineCost += getCallsiteCost(*TTI, *II, DL);
       continue;
     }
 
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
new file mode 100644
index 000000000000000..f207efaeaad36b8
--- /dev/null
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -S -passes=inline  -inlinedefault-threshold=1 | FileCheck %s
+
+; This test sets the inline-threshold to 1 such that by default the call to @streaming_callee is not inlined.
+; However, if the call to @streaming_callee requires a streaming-mode change, it should always inline the call because the streaming-mode change is more expensive.
+target triple = "aarch64"
+
+declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"
+
+define void @streaming_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_callee
+; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_compatible_f()
+  call void @streaming_compatible_f()
+  ret void
+}
+
+; Inline call to @streaming_callee to remove a streaming-mode change.
+define void @non_streaming_caller_inline() {
+; CHECK-LABEL: define void @non_streaming_caller_inline
+; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_callee()
+  ret void
+}
+
+; Don't inline call to @streaming_callee when the inline-threshold is set to 1, because it does not eliminate a streaming-mode change.
+define void @streaming_caller_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_callee()
+  ret void
+}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
index f2f5768dbe9c6e9..d6b1f3ef45e7655 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
@@ -581,3 +581,98 @@ entry:
   %res = call i64 @normal_callee_call_sme_state()
   ret i64 %res
 }
+
+
+
+declare void @streaming_body() "aarch64_pstate_sm_enabled"
+
+define void @streaming_caller_single_streaming_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_callee
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:    call void @streaming_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_body()
+  ret void
+}
+
+define void @streaming_caller_multiple_streaming_callees() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:    call void @streaming_body()
+; CHECK-NEXT:    call void @streaming_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_body()
+  call void @streaming_body()
+  ret void
+}
+
+; Allow inlining, as inlining it would not increase the number of streaming-mode changes.
+define void @streaming_caller_single_streaming_callee_inline() {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_callee_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_caller_single_streaming_callee()
+  ret void
+}
+
+; Prevent inlining, as inlining it would lead to multiple streaming-mode changes.
+define void @streaming_caller_multiple_streaming_callees_dont_inline() {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_caller_multiple_streaming_callees()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_caller_multiple_streaming_callees()
+  ret void
+}
+
+declare void @streaming_compatible_body() "aarch64_pstate_sm_compatible"
+
+define void @streaming_caller_single_streaming_compatible_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_compatible_body()
+  ret void
+}
+
+define void @streaming_caller_multiple_streaming_compatible_callees() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_compatible_body()
+  call void @streaming_compatible_body()
+  ret void
+}
+
+; Allow inlining, as inlining would remove a streaming-mode change.
+define void @streaming_caller_single_streaming_compatible_callee_inline() {
+; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_caller_single_streaming_compatible_callee()
+  ret void
+}
+
+; Allow inlining, as inlining would remove several streaming-mode changes.
+define void @streaming_caller_multiple_streaming_compatible_callees_inline() {
+; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    call void @streaming_compatible_body()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_caller_multiple_streaming_compatible_callees()
+  ret void
+}



More information about the llvm-commits mailing list