[llvm] [AArch64][SME] Allow inlining when streaming-mode attributes don't match up. (PR #68415)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 6 06:03:44 PDT 2023


https://github.com/sdesmalen-arm created https://github.com/llvm/llvm-project/pull/68415

The use-case here is to support things like:

  int foo(int x, int y) __arm_streaming { return std::max<int>(x, y); }

where the call to non-streaming `std::max<int>(x, y)` can be safely inlined into the streaming function.

This is a first step and will need further work to allow more cases (e.g. more fine-grained analysis of the function calls to ensure they don't result in any incompatible instructions for the requested mode).

>From 2ec46c7d50dde0c0dddd39c3936c57310bb09d4e Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 7 Sep 2023 17:30:32 +0100
Subject: [PATCH] [AArch64][SME] Allow inlining when streaming-mode attributes
 don't match up.

The use-case here is to support things like:

  int foo(int x, int y) __arm_streaming { return std::max<int>(x, y); }

where the call to non-streaming `std::max<int>(x, y)` can be safely inlined
into the streaming function.

This is a first step and will need further work to allow more cases (e.g.
more fine-grained analysis of the function calls to ensure they don't
result in any incompatible instructions for the requested mode).
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |  41 +++++-
 .../Inline/AArch64/sme-pstatesm-attrs.ll      | 138 ++++++++++++++----
 .../Inline/AArch64/sme-pstateza-attrs.ll      |  71 ++++++++-
 3 files changed, 214 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cded28054f59259..d053350c08bf9ab 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -190,16 +190,49 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
 
+/// Returns true if CB is a direct call to one of the SME ABI support routines
+/// (__arm_sme_state, __arm_tpidr2_save/restore or __arm_za_disable).
+static bool isSMEABIRoutineCall(const CallBase &CB) {
+  const auto *F = CB.getCalledFunction();
+  return F && StringSwitch<bool>(F->getName())
+                  .Cases("__arm_sme_state", "__arm_tpidr2_save", true)
+                  .Cases("__arm_tpidr2_restore", "__arm_za_disable", true)
+                  .Default(false);
+}
+
+/// Returns true if the function has explicit operations that can only be lowered
+/// using incompatible instructions for the selected mode.
+/// This also returns true if the function F may use or modify ZA state.
+static bool hasPossibleIncompatibleOps(const Function *F) {
+  for (const BasicBlock &BB : *F) {
+    for (const Instruction &I : BB) {
+      // Be conservative: assume any call, invoke or callbr (asm goto) to inline
+      // asm or to an intrinsic (e.g. @llvm.aarch64.*, gather/scatter) may lower
+      // to incompatible ops; native LLVM instructions are assumed compatible.
+      const auto *CB = dyn_cast<CallBase>(&I);
+      if (CB && !I.isDebugOrPseudoInst() &&
+          (CB->isInlineAsm() || isSMEABIRoutineCall(*CB) ||
+           (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                          const Function *Callee) const {
   SMEAttrs CallerAttrs(*Caller);
   SMEAttrs CalleeAttrs(*Callee);
-  if (CallerAttrs.requiresSMChange(CalleeAttrs,
-                                   /*BodyOverridesInterface=*/true) ||
-      CallerAttrs.requiresLazySave(CalleeAttrs) ||
-      CalleeAttrs.hasNewZABody())
+  if (CalleeAttrs.hasNewZABody())
     return false;
 
+  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresSMChange(CalleeAttrs,
+                                   /*BodyOverridesInterface=*/true)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
   const TargetMachine &TM = getTLI()->getTargetMachine();
 
   const FeatureBitset &CallerBits =
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
index 3df5400875ae288..f2f5768dbe9c6e9 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
@@ -102,11 +102,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_callee_dont_inline
+define void @normal_caller_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -136,11 +136,11 @@ entry:
 ; [ ] N  -> SC
 ; [x] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_dont_inline
+define void @normal_caller_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -153,11 +153,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [x] N  -> SC + B
-define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline
+define void @normal_caller_streaming_compatible_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_compatible_locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -170,11 +170,11 @@ entry:
 ; [ ] S  -> SC
 ; [ ] S  -> N + B
 ; [ ] S  -> SC + B
-define void @streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_enabled" {
-; CHECK-LABEL: define void @streaming_caller_normal_callee_dont_inline
+define void @streaming_caller_normal_callee_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -255,11 +255,11 @@ entry:
 ; [ ] N + B -> SC
 ; [ ] N + B -> N + B
 ; [ ] N + B -> SC + B
-define void @locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_body" {
-; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_dont_inline
+define void @locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR3]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -340,11 +340,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_dont_inline
+define void @streaming_compatible_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -357,11 +357,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_dont_inline
+define void @streaming_compatible_caller_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -391,11 +391,11 @@ entry:
 ; [ ] SC -> SC
 ; [x] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_dont_inline
+define void @streaming_compatible_caller_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -408,11 +408,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [x] SC -> SC + B
-define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline
+define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_compatible_locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -424,11 +424,11 @@ entry:
 ; [ ] SC + B -> SC
 ; [ ] SC + B -> N + B
 ; [ ] SC + B -> SC + B
-define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
-; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline
+define void @streaming_compatible_locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR4]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -503,3 +503,81 @@ entry:
   call void @streaming_compatible_locally_streaming_callee()
   ret void
 }
+
+define void @normal_callee_with_inlineasm() {
+; CHECK-LABEL: define void @normal_callee_with_inlineasm
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void asm sideeffect "; inlineasm", ""()
+  ret void
+}
+
+define void @streaming_caller_normal_callee_with_inlineasm_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_normal_callee_with_inlineasm_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @normal_callee_with_inlineasm()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @normal_callee_with_inlineasm() ; inline asm in callee blocks inlining
+  ret void
+}
+
+define i64 @normal_callee_with_intrinsic_call() {
+; CHECK-LABEL: define i64 @normal_callee_with_intrinsic_call
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.aarch64.sve.cntb(i32 4)
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @llvm.aarch64.sve.cntb(i32 4)
+  ret i64 %res
+}
+
+define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @normal_callee_with_intrinsic_call()
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @normal_callee_with_intrinsic_call() ; intrinsic call in callee blocks inlining
+  ret i64 %res
+}
+
+declare i64 @llvm.aarch64.sve.cntb(i32)
+
+define i64 @normal_callee_call_sme_state() {
+; CHECK-LABEL: define i64 @normal_callee_call_sme_state
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call { i64, i64 } @__arm_sme_state()
+; CHECK-NEXT:    [[RES_0:%.*]] = extractvalue { i64, i64 } [[RES]], 0
+; CHECK-NEXT:    ret i64 [[RES_0]]
+;
+entry:
+  %res = call {i64, i64} @__arm_sme_state()
+  %res.0 = extractvalue {i64, i64} %res, 0
+  ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
+define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call i64 @normal_callee_call_sme_state()
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+entry:
+  %res = call i64 @normal_callee_call_sme_state() ; SME ABI routine call in callee blocks inlining
+  ret i64 %res
+}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index a833e7a911ac03f..7b104977cff5a7b 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -3,10 +3,12 @@
 
 declare void @inlined_body()
 
+;
 ; Define some functions that will be called by the functions below.
 ; These just call a '...body()' function. If we see the call to one of
 ; these functions being replaced by '...body()', then we know it has been
 ; inlined.
+;
 
 define void @nonza_callee() {
 ; CHECK-LABEL: define void @nonza_callee
@@ -42,6 +44,7 @@ define void @new_za_callee() "aarch64_pstate_za_new" {
   ret void
 }
 
+;
 ; Now test that inlining only happens when no lazy-save is needed.
 ; Test for a number of combinations, where:
 ; N   Not using ZA.
@@ -85,7 +88,7 @@ define void @new_za_caller_nonza_callee_dont_inline() "aarch64_pstate_za_new" {
 ; CHECK-LABEL: define void @new_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @nonza_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -130,7 +133,7 @@ define void @shared_za_caller_nonza_callee_dont_inline() "aarch64_pstate_za_shar
 ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @nonza_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -167,3 +170,67 @@ entry:
   call void @shared_za_callee()
   ret void
 }
+
+define void @private_za_callee_call_za_disable() {
+; CHECK-LABEL: define void @private_za_callee_call_za_disable
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_za_disable()
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_za_disable()
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_za_disable_dont_inline() "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_za_disable_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_za_disable()
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_za_disable() ; __arm_za_disable call in callee blocks inlining
+  ret void
+}
+
+define void @private_za_callee_call_tpidr2_save() {
+; CHECK-LABEL: define void @private_za_callee_call_tpidr2_save
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_tpidr2_save()
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_tpidr2_save()
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline() "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_save()
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_tpidr2_save() ; __arm_tpidr2_save call in callee blocks inlining
+  ret void
+}
+
+define void @private_za_callee_call_tpidr2_restore(ptr %ptr) {
+; CHECK-LABEL: define void @private_za_callee_call_tpidr2_restore
+; CHECK-SAME: (ptr [[PTR:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    call void @__arm_tpidr2_restore(ptr [[PTR]])
+; CHECK-NEXT:    ret void
+;
+  call void @__arm_tpidr2_restore(ptr %ptr)
+  ret void
+}
+
+define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline(ptr %ptr) "aarch64_pstate_za_shared" {
+; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline
+; CHECK-SAME: (ptr [[PTR:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_restore(ptr [[PTR]])
+; CHECK-NEXT:    ret void
+;
+  call void @private_za_callee_call_tpidr2_restore(ptr %ptr) ; __arm_tpidr2_restore call in callee blocks inlining
+  ret void
+}
+
+declare void @__arm_za_disable()
+declare void @__arm_tpidr2_save()
+declare void @__arm_tpidr2_restore(ptr)



More information about the llvm-commits mailing list