[libcxx] [flang] [mlir] [libc] [openmp] [libcxxabi] [clang-tools-extra] [clang] [llvm] [compiler-rt] [AArch64] Combine store (trunc X to <3 x i8>) to sequence of ST1.b. (PR #78637)

Florian Hahn via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 22 08:06:21 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/78637

>From efd07e93aed51049ad3783c701284617ae446330 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 17 Jan 2024 11:11:59 +0000
Subject: [PATCH 1/2] [AArch64] Combine store (trunc X to <3 x i8>) to sequence
 of ST1.b.

Improve codegen for store (trunc X to <3 x i8>) by lowering it to a
sequence of 3 ST1.b stores: the truncate operand is first converted to
either v8i8 or v16i8, then the lanes holding the truncated results are
extracted and stored individually.

At the moment, there are almost no cases in which such vector operations
will be generated automatically. The motivating case is non-power-of-2
SLP vectorization: https://github.com/llvm/llvm-project/pull/77790
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 50 +++++++++++++++++++
 .../AArch64/vec3-loads-ext-trunc-stores.ll    | 33 +++++-------
 2 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8a6f1dc7487bae8..4be78f61fe7b651 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21318,6 +21318,53 @@ bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
          (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
 }
 
+// Combine store (trunc X to <3 x i8>) to sequence of ST1.b.
+static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
+                                   const AArch64Subtarget *Subtarget) {
+  SDValue Value = ST->getValue();
+  EVT ValueVT = Value.getValueType();
+
+  if (ST->isVolatile() || !Subtarget->isLittleEndian() ||
+      ST->getOriginalAlign() >= 4 || Value.getOpcode() != ISD::TRUNCATE ||
+      ValueVT != EVT::getVectorVT(*DAG.getContext(), MVT::i8, 3))
+    return SDValue();
+
+  SDLoc DL(ST);
+  auto WideVT = EVT::getVectorVT(
+      *DAG.getContext(),
+      Value->getOperand(0).getValueType().getVectorElementType(), 4);
+  SDValue UndefVector = DAG.getUNDEF(WideVT);
+  SDValue WideTrunc = DAG.getNode(
+      ISD::INSERT_SUBVECTOR, DL, WideVT,
+      {UndefVector, Value->getOperand(0), DAG.getVectorIdxConstant(0, DL)});
+  SDValue Cast = DAG.getNode(
+      ISD::BITCAST, DL, WideVT.getSizeInBits() == 64 ? MVT::v8i8 : MVT::v16i8,
+      WideTrunc);
+
+  SDValue Chain = ST->getChain();
+  SDValue E2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
+                           DAG.getConstant(8, DL, MVT::i64));
+
+  SDValue Ptr2 =
+      DAG.getMemBasePlusOffset(ST->getBasePtr(), TypeSize::getFixed(2), DL);
+  Chain = DAG.getStore(Chain, DL, E2, Ptr2, ST->getPointerInfo(),
+                       ST->getOriginalAlign());
+
+  SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
+                           DAG.getConstant(4, DL, MVT::i64));
+
+  SDValue Ptr1 =
+      DAG.getMemBasePlusOffset(ST->getBasePtr(), TypeSize::getFixed(1), DL);
+  Chain = DAG.getStore(Chain, DL, E1, Ptr1, ST->getPointerInfo(),
+                       ST->getOriginalAlign());
+  SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
+                           DAG.getConstant(0, DL, MVT::i64));
+  Chain = DAG.getStore(Chain, DL, E0, ST->getBasePtr(), ST->getPointerInfo(),
+                       ST->getOriginalAlign());
+
+  return Chain;
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -21333,6 +21380,9 @@ static SDValue performSTORECombine(SDNode *N,
     return EltVT == MVT::f32 || EltVT == MVT::f64;
   };
 
+  if (SDValue Res = combineI8TruncStore(ST, DAG, Subtarget))
+    return Res;
+
   // If this is an FP_ROUND followed by a store, fold this into a truncating
   // store. We can do this even if this is already a truncstore.
   // We purposefully don't care about legality of the nodes here as we know
diff --git a/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll b/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
index 9eeb194409df6fa..60639ea91fbaa18 100644
--- a/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
+++ b/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
@@ -154,17 +154,12 @@ define <3 x i32> @load_v3i32(ptr %src) {
 define void @store_trunc_from_64bits(ptr %src, ptr %dst) {
 ; CHECK-LABEL: store_trunc_from_64bits:
 ; CHECK:       ; %bb.0: ; %entry
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    ldr s0, [x0]
-; CHECK-NEXT:    ldrh w8, [x0, #4]
-; CHECK-NEXT:    mov.h v0[2], w8
-; CHECK-NEXT:    xtn.8b v0, v0
-; CHECK-NEXT:    str s0, [sp, #12]
-; CHECK-NEXT:    ldrh w9, [sp, #12]
-; CHECK-NEXT:    strb w8, [x1, #2]
-; CHECK-NEXT:    strh w9, [x1]
-; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    ld1r.4h { v0 }, [x8]
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    strb w8, [x1]
+; CHECK-NEXT:    add x8, x1, #1
+; CHECK-NEXT:    st1.b { v0 }[4], [x8]
 ; CHECK-NEXT:    ret
 ;
 ; BE-LABEL: store_trunc_from_64bits:
@@ -236,17 +231,13 @@ entry:
 define void @shift_trunc_store(ptr %src, ptr %dst) {
 ; CHECK-LABEL: shift_trunc_store:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
 ; CHECK-NEXT:    ldr q0, [x0]
-; CHECK-NEXT:    shrn.4h v0, v0, #16
-; CHECK-NEXT:    xtn.8b v1, v0
-; CHECK-NEXT:    umov.h w8, v0[2]
-; CHECK-NEXT:    str s1, [sp, #12]
-; CHECK-NEXT:    ldrh w9, [sp, #12]
-; CHECK-NEXT:    strb w8, [x1, #2]
-; CHECK-NEXT:    strh w9, [x1]
-; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    add x8, x1, #1
+; CHECK-NEXT:    add x9, x1, #2
+; CHECK-NEXT:    ushr.4s v0, v0, #16
+; CHECK-NEXT:    st1.b { v0 }[4], [x8]
+; CHECK-NEXT:    st1.b { v0 }[8], [x9]
+; CHECK-NEXT:    st1.b { v0 }[0], [x1]
 ; CHECK-NEXT:    ret
 ;
 ; BE-LABEL: shift_trunc_store:

>From 3e747fe3f7000631af555f1cdaa90bd17b93c345 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 19 Jan 2024 16:13:29 +0000
Subject: [PATCH 2/2] !fixup fix extract index computation, adjust pointer info
 and alignment by offset

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 27 ++++++++++---------
 .../AArch64/vec3-loads-ext-trunc-stores.ll    |  7 +++--
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4be78f61fe7b651..badea7537373cae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21341,26 +21341,29 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
       ISD::BITCAST, DL, WideVT.getSizeInBits() == 64 ? MVT::v8i8 : MVT::v16i8,
       WideTrunc);
 
+  unsigned IdxScale = WideVT.getScalarSizeInBits() / 8;
+  Align Align = ST->getOriginalAlign();
   SDValue Chain = ST->getChain();
   SDValue E2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
-                           DAG.getConstant(8, DL, MVT::i64));
-
-  SDValue Ptr2 =
-      DAG.getMemBasePlusOffset(ST->getBasePtr(), TypeSize::getFixed(2), DL);
-  Chain = DAG.getStore(Chain, DL, E2, Ptr2, ST->getPointerInfo(),
-                       ST->getOriginalAlign());
+                           DAG.getConstant(2 * IdxScale, DL, MVT::i64));
+  TypeSize Offset2 = TypeSize::getFixed(2);
+  SDValue Ptr2 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset2, DL);
+  Chain = DAG.getStore(Chain, DL, E2, Ptr2,
+                       ST->getPointerInfo().getWithOffset(Offset2),
+                       commonAlignment(Align, Offset2));
 
   SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
-                           DAG.getConstant(4, DL, MVT::i64));
+                           DAG.getConstant(1 * IdxScale, DL, MVT::i64));
+  TypeSize Offset1 = TypeSize::getFixed(1);
+  SDValue Ptr1 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset1, DL);
+  Chain = DAG.getStore(Chain, DL, E1, Ptr1,
+                       ST->getPointerInfo().getWithOffset(Offset1),
+                       commonAlignment(Align, Offset1));
 
-  SDValue Ptr1 =
-      DAG.getMemBasePlusOffset(ST->getBasePtr(), TypeSize::getFixed(1), DL);
-  Chain = DAG.getStore(Chain, DL, E1, Ptr1, ST->getPointerInfo(),
-                       ST->getOriginalAlign());
   SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
                            DAG.getConstant(0, DL, MVT::i64));
   Chain = DAG.getStore(Chain, DL, E0, ST->getBasePtr(), ST->getPointerInfo(),
-                       ST->getOriginalAlign());
+                       Align);
 
   return Chain;
 }
diff --git a/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll b/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
index 60639ea91fbaa18..9f783badf38f9a4 100644
--- a/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
+++ b/llvm/test/CodeGen/AArch64/vec3-loads-ext-trunc-stores.ll
@@ -154,11 +154,14 @@ define <3 x i32> @load_v3i32(ptr %src) {
 define void @store_trunc_from_64bits(ptr %src, ptr %dst) {
 ; CHECK-LABEL: store_trunc_from_64bits:
 ; CHECK:       ; %bb.0: ; %entry
-; CHECK-NEXT:    add x8, x0, #4
-; CHECK-NEXT:    ld1r.4h { v0 }, [x8]
 ; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    add x9, x0, #4
+; CHECK-NEXT:    ld1r.4h { v0 }, [x9]
+; CHECK-NEXT:    fmov s1, w8
 ; CHECK-NEXT:    strb w8, [x1]
 ; CHECK-NEXT:    add x8, x1, #1
+; CHECK-NEXT:    st1.b { v1 }[2], [x8]
+; CHECK-NEXT:    add x8, x1, #2
 ; CHECK-NEXT:    st1.b { v0 }[4], [x8]
 ; CHECK-NEXT:    ret
 ;



More information about the cfe-commits mailing list