[llvm] 18775a4 - [AArch64][SVE2] Use rshrnb for masked stores (#70026)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 26 00:42:29 PDT 2023


Author: Matthew Devereau
Date: 2023-10-26T08:42:25+01:00
New Revision: 18775a49416cf767b34eaaf925df5143e40582f4

URL: https://github.com/llvm/llvm-project/commit/18775a49416cf767b34eaaf925df5143e40582f4
DIFF: https://github.com/llvm/llvm-project/commit/18775a49416cf767b34eaaf925df5143e40582f4.diff

LOG: [AArch64][SVE2] Use rshrnb for masked stores (#70026)

This patch is a follow-up to https://reviews.llvm.org/D155299. This
patch combines add+lsr into rshrnb when 'B' in:

  C = A + B
  D = C >> Shift

is equal to (1 << (Shift-1)), and the bits in the top half of each vector
element are zeroed or ignored, such as in a truncating masked store.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7211607fee528a6..038c23b5e8d50ad 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21002,6 +21002,12 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
                       Store->getMemOperand());
 }
 
+bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
+  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
+         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
+         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -21043,16 +21049,16 @@ static SDValue performSTORECombine(SDNode *N,
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
-  if (ST->isTruncatingStore())
+  if (ST->isTruncatingStore()) {
+    EVT StoreVT = ST->getMemoryVT();
+    if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
+      return SDValue();
     if (SDValue Rshrnb =
             trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
-      EVT StoreVT = ST->getMemoryVT();
-      if ((ValueVT == MVT::nxv8i16 && StoreVT == MVT::nxv8i8) ||
-          (ValueVT == MVT::nxv4i32 && StoreVT == MVT::nxv4i16) ||
-          (ValueVT == MVT::nxv2i64 && StoreVT == MVT::nxv2i32))
-        return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
-                                 StoreVT, ST->getMemOperand());
+      return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
+                               StoreVT, ST->getMemOperand());
     }
+  }
 
   return SDValue();
 }
@@ -21098,6 +21104,19 @@ static SDValue performMSTORECombine(SDNode *N,
     }
   }
 
+  if (MST->isTruncatingStore()) {
+    EVT ValueVT = Value->getValueType(0);
+    EVT MemVT = MST->getMemoryVT();
+    if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
+      return SDValue();
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
+      return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
+                                MST->getOffset(), MST->getMask(),
+                                MST->getMemoryVT(), MST->getMemOperand(),
+                                MST->getAddressingMode(), true);
+    }
+  }
+
   return SDValue();
 }
 

diff  --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index a913177623df9ec..0afd11d098a0009 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
   store <vscale x 2 x i16> %3, ptr %4, align 1
   ret void
 }
+
+define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) {                             ; preds = %vector.body, %vector.ph
+; CHECK-LABEL: masked_store_rshrnb:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
+  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)


        


More information about the llvm-commits mailing list