[llvm] 5c96847 - [DAG][sve] Lowering for VLS masked truncating stores

David Truby via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 17 07:05:08 PST 2021


Author: David Truby
Date: 2021-12-17T15:04:45Z
New Revision: 5c9684704d15503107ce79ccaf362402ad2b0b2a

URL: https://github.com/llvm/llvm-project/commit/5c9684704d15503107ce79ccaf362402ad2b0b2a
DIFF: https://github.com/llvm/llvm-project/commit/5c9684704d15503107ce79ccaf362402ad2b0b2a.diff

LOG: [DAG][sve] Lowering for VLS masked truncating stores

This extends the custom lowering for truncating stores on
fixed-length vectors in SVE to support masked truncating stores.
It also adds a DAG combine that folds a truncate followed by a
masked store into a single masked truncating store.

Reviewed By: peterwaller-arm, paulwalker-arm

Differential Revision: https://reviews.llvm.org/D108115

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index d8da711b56f9f..83ce3d017cc5d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -852,6 +852,20 @@ class TargetLoweringBase {
     return getBooleanContents(Type.isVector(), Type.isFloatingPoint());
   }
 
+  /// Promote the given target boolean to a target boolean of the given type.
+  /// A target boolean is an integer value, not necessarily of type i1, the bits
+  /// of which conform to getBooleanContents.
+  ///
+  /// ValVT is the type of values that produced the boolean.
+  SDValue promoteTargetBoolean(SelectionDAG &DAG, SDValue Bool,
+                               EVT ValVT) const {
+    SDLoc dl(Bool);
+    EVT BoolVT =
+        getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ValVT);
+    ISD::NodeType ExtendCode = getExtendForContent(getBooleanContents(ValVT));
+    return DAG.getNode(ExtendCode, dl, BoolVT, Bool);
+  }
+
   /// Return target scheduling preference.
   Sched::Preference getSchedulingPreference() const {
     return SchedPreferenceInfo;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4637ce56ff258..bc6dd70b99bef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10044,6 +10044,8 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
   SDValue Mask = MST->getMask();
   SDValue Chain = MST->getChain();
+  SDValue Value = MST->getValue();
+  SDValue Ptr = MST->getBasePtr();
   SDLoc DL(N);
 
   // Zap masked stores with a zero mask.
@@ -10063,6 +10065,42 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
     return SDValue(N, 0);
 
+  if (MST->isTruncatingStore() && MST->isUnindexed() &&
+      Value.getValueType().isInteger() &&
+      (!isa<ConstantSDNode>(Value) ||
+       !cast<ConstantSDNode>(Value)->isOpaque())) {
+    APInt TruncDemandedBits =
+        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
+                             MST->getMemoryVT().getScalarSizeInBits());
+
+    // See if we can simplify the operation with
+    // SimplifyDemandedBits, which only works if the value has a single use.
+    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
+      // Re-visit the store if anything changed and the store hasn't been merged
+      // with another node (N is deleted). SimplifyDemandedBits will add Value's
+      // node back to the worklist if necessary, but we also need to re-visit
+      // the Store node itself.
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+  }
+
+  // If this is a TRUNC followed by a masked store, fold this into a masked
+  // truncating store.  We can do this even if this is already a masked
+  // truncstore.
+  if ((Value.getOpcode() == ISD::TRUNCATE) && Value.getNode()->hasOneUse() &&
+      MST->isUnindexed() &&
+      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
+                               MST->getMemoryVT(), LegalOperations)) {
+    auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
+                                         Value.getOperand(0).getValueType());
+    return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
+                              MST->getOffset(), Mask, MST->getMemoryVT(),
+                              MST->getMemOperand(), MST->getAddressingMode(),
+                              /*IsTruncating=*/true);
+  }
+
   return SDValue();
 }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 24874770276f2..c123ea3cd2d5a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -1007,11 +1007,7 @@ SDValue DAGTypeLegalizer::JoinIntegers(SDValue Lo, SDValue Hi) {
 ///
 /// ValVT is the type of values that produced the boolean.
 SDValue DAGTypeLegalizer::PromoteTargetBoolean(SDValue Bool, EVT ValVT) {
-  SDLoc dl(Bool);
-  EVT BoolVT = getSetCCResultType(ValVT);
-  ISD::NodeType ExtendCode =
-      TargetLowering::getExtendForContent(TLI.getBooleanContents(ValVT));
-  return DAG.getNode(ExtendCode, dl, BoolVT, Bool);
+  return TLI.promoteTargetBoolean(DAG, Bool, ValVT);
 }
 
 /// Return the lower LoVT bits of Op in Lo and the upper HiVT bits in Hi.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 702fb595de22d..7fb06d95dc1cd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18897,10 +18897,7 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
 
 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
     SDValue Op, SelectionDAG &DAG) const {
-  auto Store = cast<MaskedStoreSDNode>(Op);
-
-  if (Store->isTruncatingStore())
-    return SDValue();
+  auto *Store = cast<MaskedStoreSDNode>(Op);
 
   SDLoc DL(Op);
   EVT VT = Store->getValue().getValueType();

diff  --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
index 19bc39025dce1..73fc0c0a8f7e6 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
@@ -155,21 +155,13 @@ define void @masked_store_v64f32(<64 x float>* %ap, <64 x float>* %bp) #0 {
 define void @masked_store_trunc_v8i64i8(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i8>* %dest) #0 {
 ; VBITS_GE_512-LABEL: masked_store_trunc_v8i64i8:
 ; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1d { z1.d }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.d, p0/z, z0.d, z1.d
-; VBITS_GE_512-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_512-NEXT:    mov z1.d, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.b, vl8
-; VBITS_GE_512-NEXT:    uzp1 z1.s, z1.s, z1.s
-; VBITS_GE_512-NEXT:    uzp1 z0.h, z0.h, z0.h
-; VBITS_GE_512-NEXT:    uzp1 z1.h, z1.h, z1.h
-; VBITS_GE_512-NEXT:    uzp1 z0.b, z0.b, z0.b
-; VBITS_GE_512-NEXT:    uzp1 z1.b, z1.b, z1.b
-; VBITS_GE_512-NEXT:    cmpne p0.b, p0/z, z1.b, #0
-; VBITS_GE_512-NEXT:    st1b { z0.b }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; VBITS_GE_512-NEXT: ptrue p[[P0:[0-9]+]].d, vl8
+; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
+; VBITS_GE_512-NEXT: st1b { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
+
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
   %mask = icmp eq <8 x i64> %a, %b
@@ -179,21 +171,13 @@ define void @masked_store_trunc_v8i64i8(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i8>
 }
 
 define void @masked_store_trunc_v8i64i16(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i16>* %dest) #0 {
-; VBITS_GE_512-LABEL: masked_store_trunc_v8i64i16:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1d { z1.d }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.d, p0/z, z0.d, z1.d
-; VBITS_GE_512-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_512-NEXT:    mov z1.d, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.h, vl8
-; VBITS_GE_512-NEXT:    uzp1 z1.s, z1.s, z1.s
-; VBITS_GE_512-NEXT:    uzp1 z0.h, z0.h, z0.h
-; VBITS_GE_512-NEXT:    uzp1 z1.h, z1.h, z1.h
-; VBITS_GE_512-NEXT:    cmpne p0.h, p0/z, z1.h, #0
-; VBITS_GE_512-NEXT:    st1h { z0.h }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_store_trunc_v8i64i16:
+; VBITS_GE_512: ptrue p[[P0:[0-9]+]].d, vl8
+; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
+; VBITS_GE_512-NEXT: st1h { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
   %mask = icmp eq <8 x i64> %a, %b
@@ -203,19 +187,13 @@ define void @masked_store_trunc_v8i64i16(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i1
 }
 
 define void @masked_store_trunc_v8i64i32(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i32>* %dest) #0 {
-; VBITS_GE_512-LABEL: masked_store_trunc_v8i64i32:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1d { z1.d }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.d, p0/z, z0.d, z1.d
-; VBITS_GE_512-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_512-NEXT:    mov z1.d, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_512-NEXT:    uzp1 z1.s, z1.s, z1.s
-; VBITS_GE_512-NEXT:    cmpne p0.s, p0/z, z1.s, #0
-; VBITS_GE_512-NEXT:    st1w { z0.s }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_store_trunc_v8i64i32:
+; VBITS_GE_512: ptrue p[[P0:[0-9]+]].d, vl8
+; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
+; VBITS_GE_512-NEXT: st1w { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
   %mask = icmp eq <8 x i64> %a, %b
@@ -225,21 +203,13 @@ define void @masked_store_trunc_v8i64i32(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i3
 }
 
 define void @masked_store_trunc_v16i32i8(<16 x i32>* %ap, <16 x i32>* %bp, <16 x i8>* %dest) #0 {
-; VBITS_GE_512-LABEL: masked_store_trunc_v16i32i8:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl16
-; VBITS_GE_512-NEXT:    ld1w { z0.s }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1w { z1.s }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.s, p0/z, z0.s, z1.s
-; VBITS_GE_512-NEXT:    uzp1 z0.h, z0.h, z0.h
-; VBITS_GE_512-NEXT:    mov z1.s, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.b, vl16
-; VBITS_GE_512-NEXT:    uzp1 z1.h, z1.h, z1.h
-; VBITS_GE_512-NEXT:    uzp1 z0.b, z0.b, z0.b
-; VBITS_GE_512-NEXT:    uzp1 z1.b, z1.b, z1.b
-; VBITS_GE_512-NEXT:    cmpne p0.b, p0/z, z1.b, #0
-; VBITS_GE_512-NEXT:    st1b { z0.b }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_store_trunc_v16i32i8:
+; VBITS_GE_512: ptrue p[[P0:[0-9]+]].s, vl16
+; VBITS_GE_512-NEXT: ld1w { [[Z0:z[0-9]+]].s }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1w { [[Z1:z[0-9]+]].s }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
+; VBITS_GE_512-NEXT: st1b { [[Z0]].s }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
   %a = load <16 x i32>, <16 x i32>* %ap
   %b = load <16 x i32>, <16 x i32>* %bp
   %mask = icmp eq <16 x i32> %a, %b
@@ -249,19 +219,13 @@ define void @masked_store_trunc_v16i32i8(<16 x i32>* %ap, <16 x i32>* %bp, <16 x
 }
 
 define void @masked_store_trunc_v16i32i16(<16 x i32>* %ap, <16 x i32>* %bp, <16 x i16>* %dest) #0 {
-; VBITS_GE_512-LABEL: masked_store_trunc_v16i32i16:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl16
-; VBITS_GE_512-NEXT:    ld1w { z0.s }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1w { z1.s }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.s, p0/z, z0.s, z1.s
-; VBITS_GE_512-NEXT:    uzp1 z0.h, z0.h, z0.h
-; VBITS_GE_512-NEXT:    mov z1.s, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.h, vl16
-; VBITS_GE_512-NEXT:    uzp1 z1.h, z1.h, z1.h
-; VBITS_GE_512-NEXT:    cmpne p0.h, p0/z, z1.h, #0
-; VBITS_GE_512-NEXT:    st1h { z0.h }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_store_trunc_v16i32i16:
+; VBITS_GE_512: ptrue p[[P0:[0-9]+]].s, vl16
+; VBITS_GE_512-NEXT: ld1w { [[Z0:z[0-9]+]].s }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1w { [[Z1:z[0-9]+]].s }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
+; VBITS_GE_512-NEXT: st1h { [[Z0]].s }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
   %a = load <16 x i32>, <16 x i32>* %ap
   %b = load <16 x i32>, <16 x i32>* %bp
   %mask = icmp eq <16 x i32> %a, %b
@@ -271,19 +235,13 @@ define void @masked_store_trunc_v16i32i16(<16 x i32>* %ap, <16 x i32>* %bp, <16
 }
 
 define void @masked_store_trunc_v32i16i8(<32 x i16>* %ap, <32 x i16>* %bp, <32 x i8>* %dest) #0 {
-; VBITS_GE_512-LABEL: masked_store_trunc_v32i16i8:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_512-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_512-NEXT:    ld1h { z1.h }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    cmpeq p0.h, p0/z, z0.h, z1.h
-; VBITS_GE_512-NEXT:    uzp1 z0.b, z0.b, z0.b
-; VBITS_GE_512-NEXT:    mov z1.h, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_GE_512-NEXT:    ptrue p0.b, vl32
-; VBITS_GE_512-NEXT:    uzp1 z1.b, z1.b, z1.b
-; VBITS_GE_512-NEXT:    cmpne p0.b, p0/z, z1.b, #0
-; VBITS_GE_512-NEXT:    st1b { z0.b }, p0, [x2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_store_trunc_v32i16i8:
+; VBITS_GE_512: ptrue p[[P0:[0-9]+]].h, vl32
+; VBITS_GE_512-NEXT: ld1h { [[Z0:z[0-9]+]].h }, p0/z, [x0]
+; VBITS_GE_512-NEXT: ld1h { [[Z1:z[0-9]+]].h }, p0/z, [x1]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].h, p[[P0]]/z, [[Z0]].h, [[Z1]].h
+; VBITS_GE_512-NEXT: st1b { [[Z0]].h }, p[[P1]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: ret
   %a = load <32 x i16>, <32 x i16>* %ap
   %b = load <32 x i16>, <32 x i16>* %bp
   %mask = icmp eq <32 x i16> %a, %b


        


More information about the llvm-commits mailing list