[llvm] Fix/aarch64 memset dup optimization (PR #166030)
Osama Abdelkader via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 2 13:19:57 PST 2025
https://github.com/osamakader updated https://github.com/llvm/llvm-project/pull/166030
>From c9b595dfbdbd5893917677aa756bc9dbd4d5bcdc Mon Sep 17 00:00:00 2001
From: Osama Abdelkader <osama.abdelkader at gmail.com>
Date: Sun, 2 Nov 2025 14:41:16 +0200
Subject: [PATCH] Optimize AArch64 memset to use NEON DUP instruction for small
sizes
This change improves memset code generation for non-zero values on AArch64
for sizes of 4, 8, and 16 bytes by using NEON's DUP instruction instead of
the less efficient multiplication by the 0x01010101... byte-replication
constant.
Changes:
1. In SelectionDAG.cpp: For AArch64 targets, generate vector splats for
scalar i32/i64 memset operations, which are then efficiently lowered to
DUP instructions.
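2. In AArch64ISelLowering.cpp: Modify getOptimalMemOpType and
getOptimalMemOpLLT to return a 128-bit vector type (v16i8 and the
equivalent 2 x s64 LLT, respectively) for non-zero memset operations of
any size when NEON is available (previously only for sizes >= 32 bytes).
The 16-byte alignment / fast-unaligned-access check on this path is
removed. Also add a shallExtractConstSplatVectorElementToStore override
for v16i8 splats that are stored as 32- or 64-bit scalars.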
3. Update test expectations to verify the new DUP-based code generation
for both NEON and GPR code paths.
The SelectionDAG.cpp change is gated on the AArch64 target triple so that
code generation for other targets, and therefore the existing RISCV and
X86 tests, is unaffected.
Signed-off-by: Osama Abdelkader <osama.abdelkader at gmail.com>
---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 14 +++
.../Target/AArch64/AArch64ISelLowering.cpp | 51 ++++++++---
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 +
llvm/test/CodeGen/AArch64/memset-inline.ll | 86 ++++++++++++-------
4 files changed, 113 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 379242ec5a157..d1fcb802c5268 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8543,6 +8543,20 @@ static SDValue getMemsetValue(SDValue Value, EVT VT, SelectionDAG &DAG,
if (!IntVT.isInteger())
IntVT = EVT::getIntegerVT(*DAG.getContext(), IntVT.getSizeInBits());
+ // For repeated-byte patterns, generate a vector splat instead of MUL to
+ // enable efficient lowering to DUP on targets like AArch64.
+ // Only do this on AArch64 targets to avoid breaking other architectures.
+ const TargetMachine &TM = DAG.getTarget();
+ if (NumBits > 8 && VT.isInteger() && !VT.isVector() &&
+ (NumBits == 32 || NumBits == 64) &&
+ TM.getTargetTriple().getArch() == Triple::aarch64) {
+ // Generate a vector of bytes: v4i8 for i32, v8i8 for i64
+ EVT ByteVecTy = EVT::getVectorVT(*DAG.getContext(), MVT::i8, NumBits / 8);
+ SDValue VecSplat = DAG.getSplatBuildVector(ByteVecTy, dl, Value);
+ // Bitcast back to the target integer type
+ return DAG.getNode(ISD::BITCAST, dl, IntVT, VecSplat);
+ }
+
Value = DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, Value);
if (NumBits > 8) {
// Use a multiplication with 0x010101... to extend the input to the
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..170ae6ee8a89b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18328,10 +18328,11 @@ EVT AArch64TargetLowering::getOptimalMemOpType(
bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
- // Only use AdvSIMD to implement memset of 32-byte and above. It would have
+ // For zero memset, only use AdvSIMD for 32-byte and above. It would have
// taken one instruction to materialize the v2i64 zero and one store (with
// restrictive addressing mode). Just do i64 stores.
- bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
+ // For non-zero memset, use NEON even for smaller sizes as dup is efficient.
+ bool IsSmallZeroMemset = Op.isMemset() && Op.size() < 32 && Op.isZeroMemset();
auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
if (Op.isAligned(AlignCheck))
return true;
@@ -18341,10 +18342,12 @@ EVT AArch64TargetLowering::getOptimalMemOpType(
Fast;
};
- if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
- AlignmentIsAcceptable(MVT::v16i8, Align(16)))
+ // Use NEON for every memset except zero memsets below 32 bytes; a non-zero
+ // value is materialized with dup, so even small sizes are efficient.
+ if (CanUseNEON && Op.isMemset() && !IsSmallZeroMemset)
return MVT::v16i8;
- if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
+ if (CanUseFP && !IsSmallZeroMemset &&
+ AlignmentIsAcceptable(MVT::f128, Align(16)))
return MVT::f128;
if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
return MVT::i64;
@@ -18358,10 +18361,11 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
- // Only use AdvSIMD to implement memset of 32-byte and above. It would have
+ // For zero memset, only use AdvSIMD for 32-byte and above. It would have
// taken one instruction to materialize the v2i64 zero and one store (with
// restrictive addressing mode). Just do i64 stores.
- bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
+ // For non-zero memset, use NEON even for smaller sizes as dup is efficient.
+ bool IsSmallZeroMemset = Op.isMemset() && Op.size() < 32 && Op.isZeroMemset();
auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
if (Op.isAligned(AlignCheck))
return true;
@@ -18371,10 +18375,12 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
Fast;
};
- if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
- AlignmentIsAcceptable(MVT::v2i64, Align(16)))
+ // Use NEON for every memset except zero memsets below 32 bytes. No alignment
+ // check: dup + store is assumed to work efficiently for any alignment.
+ if (CanUseNEON && Op.isMemset() && !IsSmallZeroMemset)
return LLT::fixed_vector(2, 64);
- if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
+ if (CanUseFP && !IsSmallZeroMemset &&
+ AlignmentIsAcceptable(MVT::f128, Align(16)))
return LLT::scalar(128);
if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
return LLT::scalar(64);
@@ -29702,6 +29708,31 @@ AArch64TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
.getInstr();
}
+bool AArch64TargetLowering::shallExtractConstSplatVectorElementToStore(
+ Type *VectorTy, unsigned ElemSizeInBits, unsigned &Index) const {
+ // Returns true, with Index set to 0, only when VectorTy is a fixed 16 x i8
+ // vector and ElemSizeInBits is 32 or 64; any other input returns false and
+ // leaves Index untouched. This serves memset lowering, where a v16i8 splat
+ // is built but a narrower scalar is stored (e.g. str s0 for 4 bytes).
+ if (FixedVectorType *VTy = dyn_cast<FixedVectorType>(VectorTy)) {
+ // Restrict to v16i8: 16 elements of 8 bits each, 128 bits in total.
+ if (VTy->getNumElements() == 16 && VTy->getElementType()->isIntegerTy(8)) {
+ // Only 32-bit and 64-bit stores are accepted; both report Index 0.
+ if (ElemSizeInBits == 32) {
+ // 32-bit store: take the lowest 32-bit lane of the 128-bit vector.
+ Index = 0;
+ return true;
+ }
+ if (ElemSizeInBits == 64) {
+ // 64-bit store: take the lowest 64-bit lane (bytes 0-7) of the vector.
+ Index = 0;
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
bool AArch64TargetLowering::enableAggressiveFMAFusion(EVT VT) const {
return Subtarget->hasAggressiveFMA() && VT.isFloatingPoint();
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2cb8ed29f252a..37fadf8a2b0b1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -475,6 +475,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock::instr_iterator &MBBI,
const TargetInstrInfo *TII) const override;
+ bool shallExtractConstSplatVectorElementToStore(
+ Type *VectorTy, unsigned ElemSizeInBits, unsigned &Index) const override;
+
/// Enable aggressive FMA fusion on targets that want it.
bool enableAggressiveFMAFusion(EVT VT) const override;
diff --git a/llvm/test/CodeGen/AArch64/memset-inline.ll b/llvm/test/CodeGen/AArch64/memset-inline.ll
index 02d852b5ce45a..ed9a752dc1f8d 100644
--- a/llvm/test/CodeGen/AArch64/memset-inline.ll
+++ b/llvm/test/CodeGen/AArch64/memset-inline.ll
@@ -27,39 +27,57 @@ define void @memset_2(ptr %a, i8 %value) nounwind {
}
define void @memset_4(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_4:
-; ALL: // %bb.0:
-; ALL-NEXT: mov w8, #16843009
-; ALL-NEXT: and w9, w1, #0xff
-; ALL-NEXT: mul w8, w9, w8
-; ALL-NEXT: str w8, [x0]
-; ALL-NEXT: ret
+; GPR-LABEL: memset_4:
+; GPR: // %bb.0:
+; GPR-NEXT: mov w8, #16843009
+; GPR-NEXT: and w9, w1, #0xff
+; GPR-NEXT: mul w8, w9, w8
+; GPR-NEXT: str w8, [x0]
+; GPR-NEXT: ret
+;
+; NEON-LABEL: memset_4:
+; NEON: // %bb.0:
+; NEON-NEXT: dup v0.8b, w1
+; NEON-NEXT: str s0, [x0]
+; NEON-NEXT: ret
tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 4, i1 0)
ret void
}
define void @memset_8(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_8:
-; ALL: // %bb.0:
-; ALL-NEXT: // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT: mov x8, #72340172838076673
-; ALL-NEXT: and x9, x1, #0xff
-; ALL-NEXT: mul x8, x9, x8
-; ALL-NEXT: str x8, [x0]
-; ALL-NEXT: ret
+; GPR-LABEL: memset_8:
+; GPR: // %bb.0:
+; GPR-NEXT: // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT: mov x8, #72340172838076673
+; GPR-NEXT: and x9, x1, #0xff
+; GPR-NEXT: mul x8, x9, x8
+; GPR-NEXT: str x8, [x0]
+; GPR-NEXT: ret
+;
+; NEON-LABEL: memset_8:
+; NEON: // %bb.0:
+; NEON-NEXT: dup v0.8b, w1
+; NEON-NEXT: str d0, [x0]
+; NEON-NEXT: ret
tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 8, i1 0)
ret void
}
define void @memset_16(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_16:
-; ALL: // %bb.0:
-; ALL-NEXT: // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT: mov x8, #72340172838076673
-; ALL-NEXT: and x9, x1, #0xff
-; ALL-NEXT: mul x8, x9, x8
-; ALL-NEXT: stp x8, x8, [x0]
-; ALL-NEXT: ret
+; GPR-LABEL: memset_16:
+; GPR: // %bb.0:
+; GPR-NEXT: // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT: mov x8, #72340172838076673
+; GPR-NEXT: and x9, x1, #0xff
+; GPR-NEXT: mul x8, x9, x8
+; GPR-NEXT: stp x8, x8, [x0]
+; GPR-NEXT: ret
+;
+; NEON-LABEL: memset_16:
+; NEON: // %bb.0:
+; NEON-NEXT: dup v0.16b, w1
+; NEON-NEXT: str q0, [x0]
+; NEON-NEXT: ret
tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 16, i1 0)
ret void
}
@@ -110,14 +128,20 @@ define void @memset_64(ptr %a, i8 %value) nounwind {
; /////////////////////////////////////////////////////////////////////////////
define void @aligned_memset_16(ptr align 16 %a, i8 %value) nounwind {
-; ALL-LABEL: aligned_memset_16:
-; ALL: // %bb.0:
-; ALL-NEXT: // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT: mov x8, #72340172838076673
-; ALL-NEXT: and x9, x1, #0xff
-; ALL-NEXT: mul x8, x9, x8
-; ALL-NEXT: stp x8, x8, [x0]
-; ALL-NEXT: ret
+; GPR-LABEL: aligned_memset_16:
+; GPR: // %bb.0:
+; GPR-NEXT: // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT: mov x8, #72340172838076673
+; GPR-NEXT: and x9, x1, #0xff
+; GPR-NEXT: mul x8, x9, x8
+; GPR-NEXT: stp x8, x8, [x0]
+; GPR-NEXT: ret
+;
+; NEON-LABEL: aligned_memset_16:
+; NEON: // %bb.0:
+; NEON-NEXT: dup v0.16b, w1
+; NEON-NEXT: str q0, [x0]
+; NEON-NEXT: ret
tail call void @llvm.memset.inline.p0.i64(ptr align 16 %a, i8 %value, i64 16, i1 0)
ret void
}
More information about the llvm-commits
mailing list