[llvm] [AArch64][SME] Allow memory operations lowering to custom SME functions. (PR #79263)
Dinar Temirbulatov via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 30 01:11:26 PST 2024
https://github.com/dtemirbulatov updated https://github.com/llvm/llvm-project/pull/79263
>From 2ef16b03a5611701c215b574a38688d1febf42b7 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Wed, 24 Jan 2024 08:14:07 +0000
Subject: [PATCH 1/2] [AArch64][SME] Enable memory operations lowering to
custom SME functions.
This change allows memcpy, memset and memmove to be lowered to the custom SME
versions provided by LibRT.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 10 ++-
.../AArch64/AArch64SelectionDAGInfo.cpp | 72 +++++++++++++++++++
.../Target/AArch64/AArch64SelectionDAGInfo.h | 4 ++
.../AArch64/Utils/AArch64SMEAttributes.cpp | 3 +
llvm/test/CodeGen/AArch64/sme2-mops.ll | 67 +++++++++++++++++
5 files changed, 154 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sme2-mops.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f572772d3c980..6385b341bcf63 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7659,8 +7659,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
if (CLI.CB)
CalleeAttrs = SMEAttrs(*CLI.CB);
- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
- CalleeAttrs = SMEAttrs(ES->getSymbol());
+ else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) {
+ if (StringRef(ES->getSymbol()) == StringRef("__arm_sc_memcpy")) {
+ auto Attrs = AttributeList().addFnAttribute(
+ *DAG.getContext(), "aarch64_pstate_sm_compatible");
+ CalleeAttrs = SMEAttrs(Attrs);
+ } else
+ CalleeAttrs = SMEAttrs(ES->getSymbol());
+ }
auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 9e43f206efcf7..fff4e2333194e 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -76,12 +76,74 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
}
}
+SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
+ SDValue Size, RTLIB::Libcall LC) const {
+ const AArch64Subtarget &STI =
+ DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+ const AArch64TargetLowering *TLI = STI.getTargetLowering();
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
+ Entry.Node = Dst;
+ Args.push_back(Entry);
+
+ enum { SME_MEMCPY = 0, SME_MEMMOVE, SME_MEMSET } SMELibcall;
+ switch (LC) {
+ case RTLIB::MEMCPY:
+ SMELibcall = SME_MEMCPY;
+ Entry.Node = Src;
+ Args.push_back(Entry);
+ break;
+ case RTLIB::MEMMOVE:
+ SMELibcall = SME_MEMMOVE;
+ Entry.Node = Src;
+ Args.push_back(Entry);
+ break;
+ case RTLIB::MEMSET:
+ SMELibcall = SME_MEMSET;
+ if (Src.getValueType().bitsGT(MVT::i32))
+ Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
+ else if (Src.getValueType().bitsLT(MVT::i32))
+ Src = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, Src);
+ Entry.Node = Src;
+ Entry.Ty = Type::getInt32Ty(*DAG.getContext());
+ Entry.IsSExt = false;
+ Args.push_back(Entry);
+ break;
+ default:
+ return SDValue();
+ }
+ Entry.Node = Size;
+ Args.push_back(Entry);
+ char const *FunctionNames[3] = {"__arm_sc_memcpy", "__arm_sc_memmove",
+ "__arm_sc_memset"};
+
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ CLI.setDebugLoc(DL)
+ .setChain(Chain)
+ .setLibCallee(
+ TLI->getLibcallCallingConv(RTLIB::MEMCPY),
+ Type::getVoidTy(*DAG.getContext()),
+ DAG.getExternalSymbol(FunctionNames[SMELibcall],
+ TLI->getPointerTy(DAG.getDataLayout())),
+ std::move(Args))
+ .setDiscardResult();
+ std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
+ return CallResult.second;
+}
+
SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+ SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+ return EmitSpecializedLibcall(DAG, DL, Chain, Dst, Src, Size,
+ RTLIB::MEMCPY);
if (STI.hasMOPS())
return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
@@ -95,6 +157,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+ SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+ return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+ RTLIB::MEMSET);
+
if (STI.hasMOPS()) {
return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
@@ -108,6 +175,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+ SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+ return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+ RTLIB::MEMMOVE);
if (STI.hasMOPS()) {
return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
index 73f93724d6fc7..9c55c21f3c320 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
@@ -47,6 +47,10 @@ class AArch64SelectionDAGInfo : public SelectionDAGTargetInfo {
SDValue Chain, SDValue Op1, SDValue Op2,
MachinePointerInfo DstPtrInfo,
bool ZeroData) const override;
+
+ SDValue EmitSpecializedLibcall(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, SDValue Dst, SDValue Src,
+ SDValue Size, RTLIB::Libcall LC) const;
};
}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 3ee54e5df0a13..5080e4a0b4f9a 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -51,6 +51,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_tpidr2_restore")
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
SMEAttrs::SME_ABI_Routine);
+ if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+ FuncName == "__arm_sc_memmove")
+ Bitmask |= SMEAttrs::SM_Compatible;
}
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme2-mops.ll b/llvm/test/CodeGen/AArch64/sme2-mops.ll
new file mode 100644
index 0000000000000..0ded6e965ecb9
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme2-mops.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs < %s | FileCheck %s
+
+@dst = global [512 x i8] zeroinitializer, align 1
+@src = global [512 x i8] zeroinitializer, align 1
+
+
+define void @sc_memcpy(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memcpy:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 16
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: mov x2, x0
+; CHECK-NEXT: adrp x0, :got:dst
+; CHECK-NEXT: adrp x1, :got:src
+; CHECK-NEXT: ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT: ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT: bl __arm_sc_memcpy
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+entry:
+ tail call void @llvm.memcpy.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+ ret void
+}
+
+define void @sc_memset(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memset:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 16
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: mov x2, x0
+; CHECK-NEXT: adrp x0, :got:dst
+; CHECK-NEXT: mov w1, #2 // =0x2
+; CHECK-NEXT: ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT: // kill: def $w2 killed $w2 killed $x2
+; CHECK-NEXT: bl __arm_sc_memset
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+entry:
+ tail call void @llvm.memset.p0.i64(ptr align 1 @dst, i8 2, i64 %n, i1 false)
+ ret void
+}
+
+define void @sc_memmove(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memmove:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 16
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: mov x2, x0
+; CHECK-NEXT: adrp x0, :got:dst
+; CHECK-NEXT: adrp x1, :got:src
+; CHECK-NEXT: ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT: ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT: bl __arm_sc_memmove
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+entry:
+ tail call void @llvm.memmove.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+ ret void
+}
+
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
+declare void @llvm.memcpy.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)
+declare void @llvm.memmove.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)
>From 736bcd7eeb54af03f866329b99f18ee61b9df6b7 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Tue, 30 Jan 2024 09:08:43 +0000
Subject: [PATCH 2/2] Resolved comments
---
.../Target/AArch64/AArch64ISelLowering.cpp | 10 ++------
.../AArch64/AArch64SelectionDAGInfo.cpp | 24 +++++++------------
llvm/test/CodeGen/AArch64/sme2-mops.ll | 1 -
3 files changed, 10 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6385b341bcf63..f572772d3c980 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7659,14 +7659,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
if (CLI.CB)
CalleeAttrs = SMEAttrs(*CLI.CB);
- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) {
- if (StringRef(ES->getSymbol()) == StringRef("__arm_sc_memcpy")) {
- auto Attrs = AttributeList().addFnAttribute(
- *DAG.getContext(), "aarch64_pstate_sm_compatible");
- CalleeAttrs = SMEAttrs(Attrs);
- } else
- CalleeAttrs = SMEAttrs(ES->getSymbol());
- }
+ else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+ CalleeAttrs = SMEAttrs(ES->getSymbol());
auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index fff4e2333194e..1c4142e535793 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -84,28 +84,26 @@ SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
const AArch64TargetLowering *TLI = STI.getTargetLowering();
TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;
+ SDValue Symbol;
Entry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
Entry.Node = Dst;
Args.push_back(Entry);
+ EVT Ty = TLI->getPointerTy(DAG.getDataLayout());
- enum { SME_MEMCPY = 0, SME_MEMMOVE, SME_MEMSET } SMELibcall;
switch (LC) {
case RTLIB::MEMCPY:
- SMELibcall = SME_MEMCPY;
+ Symbol = DAG.getExternalSymbol("__arm_sc_memcpy", Ty);
Entry.Node = Src;
Args.push_back(Entry);
break;
case RTLIB::MEMMOVE:
- SMELibcall = SME_MEMMOVE;
+ Symbol = DAG.getExternalSymbol("__arm_sc_memmove", Ty);
Entry.Node = Src;
Args.push_back(Entry);
break;
case RTLIB::MEMSET:
- SMELibcall = SME_MEMSET;
- if (Src.getValueType().bitsGT(MVT::i32))
- Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
- else if (Src.getValueType().bitsLT(MVT::i32))
- Src = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, Src);
+ Symbol = DAG.getExternalSymbol("__arm_sc_memset", Ty);
+ Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
Entry.Node = Src;
Entry.Ty = Type::getInt32Ty(*DAG.getContext());
Entry.IsSExt = false;
@@ -116,18 +114,12 @@ SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
}
Entry.Node = Size;
Args.push_back(Entry);
- char const *FunctionNames[3] = {"__arm_sc_memcpy", "__arm_sc_memmove",
- "__arm_sc_memset"};
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL)
.setChain(Chain)
- .setLibCallee(
- TLI->getLibcallCallingConv(RTLIB::MEMCPY),
- Type::getVoidTy(*DAG.getContext()),
- DAG.getExternalSymbol(FunctionNames[SMELibcall],
- TLI->getPointerTy(DAG.getDataLayout())),
- std::move(Args))
+ .setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
+ Type::getVoidTy(*DAG.getContext()), Symbol, std::move(Args))
.setDiscardResult();
std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
diff --git a/llvm/test/CodeGen/AArch64/sme2-mops.ll b/llvm/test/CodeGen/AArch64/sme2-mops.ll
index 0ded6e965ecb9..0599bc61a52f7 100644
--- a/llvm/test/CodeGen/AArch64/sme2-mops.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-mops.ll
@@ -4,7 +4,6 @@
@dst = global [512 x i8] zeroinitializer, align 1
@src = global [512 x i8] zeroinitializer, align 1
-
define void @sc_memcpy(i64 noundef %n) "aarch64_pstate_sm_compatible" {
; CHECK-LABEL: sc_memcpy:
; CHECK: // %bb.0: // %entry
More information about the llvm-commits
mailing list