[llvm] [AArch64][SVE2] Use rshrnb for masked stores (PR #70026)
Matthew Devereau via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 25 02:31:13 PDT 2023
https://github.com/MDevereau updated https://github.com/llvm/llvm-project/pull/70026
>From 512193b6f19e9e3a1f613308c2d7bbdcbd3846b5 Mon Sep 17 00:00:00 2001
From: Matt Devereau <matthew.devereau at arm.com>
Date: Tue, 24 Oct 2023 10:51:36 +0000
Subject: [PATCH 1/2] [AArch64][SVE2] Use rshrnb for masked stores
This patch is a follow up on https://reviews.llvm.org/D155299.
This patch combines add+lsr to rshrnb when 'B' in:
C = A + B
D = C >> Shift
is equal to (1 << (Shift-1)), and the bits in the top half
of each vector element are zeroed or ignored, such as in a
truncating masked store.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 15 +++++++++++++++
.../AArch64/sve2-intrinsics-combine-rshrnb.ll | 19 +++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a16a102e472e709..e09ebe01a336c96 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21017,6 +21017,21 @@ static SDValue performMSTORECombine(SDNode *N,
}
}
+ if (MST->isTruncatingStore()) {
+ if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
+ EVT ValueVT = Value->getValueType(0);
+ EVT MemVT = MST->getMemoryVT();
+ if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
+ (ValueVT == MVT::nxv4i32 && MemVT == MVT::nxv4i16) ||
+ (ValueVT == MVT::nxv2i64 && MemVT == MVT::nxv2i32)) {
+ return DAG.getMaskedStore(
+ MST->getChain(), DL, Rshrnb, MST->getBasePtr(), MST->getOffset(),
+ MST->getMask(), MST->getMemoryVT(), MST->getMemOperand(),
+ MST->getAddressingMode(), true);
+ }
+ }
+ }
+
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index a913177623df9ec..0afd11d098a0009 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
store <vscale x 2 x i16> %3, ptr %4, align 1
ret void
}
+
+define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) { ; preds = %vector.body, %vector.ph
+; CHECK-LABEL: masked_store_rshrnb:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT: rshrnb z0.b, z0.h, #6
+; CHECK-NEXT: st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT: ret
+ %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
+ %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+ %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+ %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+ %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+ tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
>From 6daa917ac8e826f4fc37f3c5e24d53c4c165b44c Mon Sep 17 00:00:00 2001
From: Matt Devereau <matthew.devereau at arm.com>
Date: Wed, 25 Oct 2023 09:30:03 +0000
Subject: [PATCH 2/2] Add helper function for MVTs
---
.../Target/AArch64/AArch64ISelLowering.cpp | 38 ++++++++++---------
1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e09ebe01a336c96..f9bce864b132b8b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20921,6 +20921,12 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
Store->getMemOperand());
}
+bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
+ return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
+ (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
+ (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
+}
+
static SDValue performSTORECombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG,
@@ -20962,16 +20968,16 @@ static SDValue performSTORECombine(SDNode *N,
if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
return Store;
- if (ST->isTruncatingStore())
+ if (ST->isTruncatingStore()) {
+ EVT StoreVT = ST->getMemoryVT();
+ if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
+ return SDValue();
if (SDValue Rshrnb =
trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
- EVT StoreVT = ST->getMemoryVT();
- if ((ValueVT == MVT::nxv8i16 && StoreVT == MVT::nxv8i8) ||
- (ValueVT == MVT::nxv4i32 && StoreVT == MVT::nxv4i16) ||
- (ValueVT == MVT::nxv2i64 && StoreVT == MVT::nxv2i32))
- return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
- StoreVT, ST->getMemOperand());
+ return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
+ StoreVT, ST->getMemOperand());
}
+ }
return SDValue();
}
@@ -21018,17 +21024,15 @@ static SDValue performMSTORECombine(SDNode *N,
}
if (MST->isTruncatingStore()) {
+ EVT ValueVT = Value->getValueType(0);
+ EVT MemVT = MST->getMemoryVT();
+ if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
+ return SDValue();
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
- EVT ValueVT = Value->getValueType(0);
- EVT MemVT = MST->getMemoryVT();
- if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
- (ValueVT == MVT::nxv4i32 && MemVT == MVT::nxv4i16) ||
- (ValueVT == MVT::nxv2i64 && MemVT == MVT::nxv2i32)) {
- return DAG.getMaskedStore(
- MST->getChain(), DL, Rshrnb, MST->getBasePtr(), MST->getOffset(),
- MST->getMask(), MST->getMemoryVT(), MST->getMemOperand(),
- MST->getAddressingMode(), true);
- }
+ return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
+ MST->getOffset(), MST->getMask(),
+ MST->getMemoryVT(), MST->getMemOperand(),
+ MST->getAddressingMode(), true);
}
}
More information about the llvm-commits
mailing list