[llvm] [AArch64][SME] Allow memory operations lowering to custom SME functions. (PR #79263)

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 24 14:00:23 PST 2024


https://github.com/dtemirbulatov updated https://github.com/llvm/llvm-project/pull/79263

>From ff372f2ed6bf6aa56044bc4886b502f6ca933180 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Wed, 24 Jan 2024 08:14:07 +0000
Subject: [PATCH] [AArch64][SME] Enable memory operations lowering to custom
 SME functions.

This change allows lowering memcpy, memset and memmove to the custom
streaming-compatible SME routines provided by LibRT (__arm_sc_memcpy,
__arm_sc_memset and __arm_sc_memmove).
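
For example, in a streaming-compatible function a memcpy such as the one
below is now lowered to a call to __arm_sc_memcpy instead of a plain call
to memcpy (a minimal sketch modelled on the added test; the function name
@copy is illustrative):

  define void @copy(ptr %dst, ptr %src, i64 %n) "aarch64_pstate_sm_compatible" {
    call void @llvm.memcpy.p0.p0.i64(ptr align 1 %dst, ptr align 1 %src, i64 %n, i1 false)
    ret void
  }
  declare void @llvm.memcpy.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)

With this patch the call above is selected as "bl __arm_sc_memcpy" even when
the subtarget has MOPS, since the streaming check is performed first.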
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 10 ++-
 .../AArch64/AArch64SelectionDAGInfo.cpp       | 72 +++++++++++++++++++
 .../Target/AArch64/AArch64SelectionDAGInfo.h  |  4 ++
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |  3 +
 llvm/test/CodeGen/AArch64/sme2-mops.ll        | 67 +++++++++++++++++
 5 files changed, 154 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme2-mops.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 332fb37655288ce..64936b9c86ac1b9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7659,8 +7659,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
   if (CLI.CB)
     CalleeAttrs = SMEAttrs(*CLI.CB);
-  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-    CalleeAttrs = SMEAttrs(ES->getSymbol());
+  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) {
+    if (StringRef(ES->getSymbol()) == "__arm_sc_memcpy") {
+      auto Attrs = AttributeList().addFnAttribute(
+          *DAG.getContext(), "aarch64_pstate_sm_compatible");
+      CalleeAttrs = SMEAttrs(Attrs);
+    } else
+      CalleeAttrs = SMEAttrs(ES->getSymbol());
+  }
 
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 9e43f206efcf786..fff4e2333194e32 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -76,12 +76,74 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
   }
 }
 
+SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
+    SDValue Size, RTLIB::Libcall LC) const {
+  const AArch64Subtarget &STI =
+      DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+  const AArch64TargetLowering *TLI = STI.getTargetLowering();
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
+  Entry.Node = Dst;
+  Args.push_back(Entry);
+
+  enum { SME_MEMCPY = 0, SME_MEMMOVE, SME_MEMSET } SMELibcall;
+  switch (LC) {
+  case RTLIB::MEMCPY:
+    SMELibcall = SME_MEMCPY;
+    Entry.Node = Src;
+    Args.push_back(Entry);
+    break;
+  case RTLIB::MEMMOVE:
+    SMELibcall = SME_MEMMOVE;
+    Entry.Node = Src;
+    Args.push_back(Entry);
+    break;
+  case RTLIB::MEMSET:
+    SMELibcall = SME_MEMSET;
+    if (Src.getValueType().bitsGT(MVT::i32))
+      Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
+    else if (Src.getValueType().bitsLT(MVT::i32))
+      Src = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, Src);
+    Entry.Node = Src;
+    Entry.Ty = Type::getInt32Ty(*DAG.getContext());
+    Entry.IsSExt = false;
+    Args.push_back(Entry);
+    break;
+  default:
+    return SDValue();
+  }
+  Entry.Node = Size;
+  Args.push_back(Entry);
+  char const *FunctionNames[3] = {"__arm_sc_memcpy", "__arm_sc_memmove",
+                                  "__arm_sc_memset"};
+
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(Chain)
+      .setLibCallee(
+          TLI->getLibcallCallingConv(LC),
+          Type::getVoidTy(*DAG.getContext()),
+          DAG.getExternalSymbol(FunctionNames[SMELibcall],
+                                TLI->getPointerTy(DAG.getDataLayout())),
+          std::move(Args))
+      .setDiscardResult();
+  std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
+  return CallResult.second;
+}
+
 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, DL, Chain, Dst, Src, Size,
+                                  RTLIB::MEMCPY);
   if (STI.hasMOPS())
     return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
@@ -95,6 +157,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
 
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+                                  RTLIB::MEMSET);
+
   if (STI.hasMOPS()) {
     return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
@@ -108,6 +175,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+                                  RTLIB::MEMMOVE);
   if (STI.hasMOPS()) {
     return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
index 73f93724d6fc736..9c55c21f3c3202e 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
@@ -47,6 +47,10 @@ class AArch64SelectionDAGInfo : public SelectionDAGTargetInfo {
                                   SDValue Chain, SDValue Op1, SDValue Op2,
                                   MachinePointerInfo DstPtrInfo,
                                   bool ZeroData) const override;
+
+  SDValue EmitSpecializedLibcall(SelectionDAG &DAG, const SDLoc &DL,
+                                 SDValue Chain, SDValue Dst, SDValue Src,
+                                 SDValue Size, RTLIB::Libcall LC) const;
 };
 }
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 3ee54e5df0a13df..5080e4a0b4f9a2e 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -51,6 +51,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
                 SMEAttrs::SME_ABI_Routine);
+  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+      FuncName == "__arm_sc_memmove")
+    Bitmask |= SMEAttrs::SM_Compatible;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme2-mops.ll b/llvm/test/CodeGen/AArch64/sme2-mops.ll
new file mode 100644
index 000000000000000..0ded6e965ecb9c5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme2-mops.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs < %s | FileCheck %s
+
+ at dst = global [512 x i8] zeroinitializer, align 1
+ at src = global [512 x i8] zeroinitializer, align 1
+
+
+define void @sc_memcpy(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memcpy:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    adrp x1, :got:src
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT:    bl __arm_sc_memcpy
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memcpy.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+  ret void
+}
+
+define void @sc_memset(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memset:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    mov w1, #2 // =0x2
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    // kill: def $w2 killed $w2 killed $x2
+; CHECK-NEXT:    bl __arm_sc_memset
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memset.p0.i64(ptr align 1 @dst, i8 2, i64 %n, i1 false)
+  ret void
+}
+
+define void @sc_memmove(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memmove:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    adrp x1, :got:src
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT:    bl __arm_sc_memmove
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memmove.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
+declare void @llvm.memcpy.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)
+declare void @llvm.memmove.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)


