[llvm] 5ab9000 - [AArch64][SVE] Fix selection failures for scalable MLOAD nodes with passthru

Bradley Smith via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 6 07:40:58 PDT 2021


Author: Bradley Smith
Date: 2021-07-06T14:17:23Z
New Revision: 5ab9000fbb3057336ca33721096bb8766cb5b675

URL: https://github.com/llvm/llvm-project/commit/5ab9000fbb3057336ca33721096bb8766cb5b675
DIFF: https://github.com/llvm/llvm-project/commit/5ab9000fbb3057336ca33721096bb8766cb5b675.diff

LOG: [AArch64][SVE] Fix selection failures for scalable MLOAD nodes with passthru

Differential Revision: https://reviews.llvm.org/D105348

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
    llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
    llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 31f97b5fae2d..e4fb9b7ae967 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1154,6 +1154,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::MUL, VT, Custom);
       setOperationAction(ISD::MULHS, VT, Custom);
       setOperationAction(ISD::MULHU, VT, Custom);
@@ -1245,6 +1246,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::FADD, VT, Custom);
@@ -1280,6 +1282,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
     }
 
     setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
@@ -4476,6 +4479,32 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   return DAG.getNode(Opcode, DL, VTs, Ops);
 }
 
+SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const { // Custom lowering entry for ISD::MLOAD nodes.
+  SDLoc DL(Op);
+  MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
+  assert(LoadNode && "Expected custom lowering of a masked load node"); // NOTE(review): probably redundant, cast<> is not expected to yield null - confirm.
+  EVT VT = Op->getValueType(0);
+  // Fixed-length vectors that are lowered via SVE keep their dedicated path.
+  if (useSVEForFixedLengthVectorVT(VT, true))
+    return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
+  // Any other masked load: the decision below depends only on the passthru.
+  SDValue PassThru = LoadNode->getPassThru();
+  SDValue Mask = LoadNode->getMask();
+  // An undef or all-zeros passthru needs no rewrite; return the node as-is.
+  if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
+    return Op;
+  // Otherwise re-create the same masked load, but with an undef passthru ...
+  SDValue Load = DAG.getMaskedLoad(
+      VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
+      LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
+      LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
+      LoadNode->getExtensionType());
+  // ... and choose between loaded value and original passthru using the mask.
+  SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);
+  // Pair the selected value with the new load's result #1 (presumably chain).
+  return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+}
+
 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
                                         EVT VT, EVT MemVT,
@@ -4854,7 +4883,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::TRUNCATE:
     return LowerTRUNCATE(Op, DAG);
   case ISD::MLOAD:
-    return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
+    return LowerMLOAD(Op, DAG);
   case ISD::LOAD:
     if (useSVEForFixedLengthVectorVT(Op.getValueType()))
       return LowerFixedLengthVectorLoadToSVE(Op, DAG);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 3ce04314ba26..392910560eff 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -859,6 +859,8 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
index 634c2907c76a..47a8e1a5db3e 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -92,6 +92,15 @@ define <vscale x 8 x bfloat> @masked_load_nxv8bf16(<vscale x 8 x bfloat> *%a, <v
   ret <vscale x 8 x bfloat> %load
 }
 
+define <vscale x 4 x i32> @masked_load_passthru(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) nounwind {
+; CHECK-LABEL: masked_load_passthru:
+; CHECK-NEXT: ld1w { z1.s }, p0/z, [x0]
+; CHECK-NEXT: mov z0.s, p0/m, z1.s
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) ; passthru is a function argument (neither undef nor zero): expected output is a zeroing load plus a predicated mov
+  ret <vscale x 4 x i32> %load
+}
+
 ;
 ; Masked Stores
 ;

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
index 8edd35435101..f7efa546bb0b 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -58,6 +58,18 @@ define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x
   ret <vscale x 8 x i16> %ext
 }
 
+define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
+; CHECK-LABEL: masked_sload_passthru:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG1:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ptrue [[PG2:p[0-9]+]].d
+; CHECK-NEXT: sxtw z0.d, [[PG2]]/m, z0.d
+; CHECK-NEXT: mov z0.d, [[PG1]]/m, [[IN]].d
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64> ; sign-extension of a masked load with an argument passthru: a sign-extending load is expected, with z0 (presumably the passthru - confirm) extended separately before the predicated mov
+  ret <vscale x 2 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index 1747616d673e..7dbebee3d1ee 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -64,6 +64,18 @@ define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale
   ret <vscale x 8 x i16> %ext
 }
 
+define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
+; CHECK-LABEL: masked_zload_passthru:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
+; CHECK-NEXT: mov z0.d, [[PG]]/m, [[IN]].d
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64> ; zero-extension of a masked load with an argument passthru: a plain (non sign-extending) load is expected, with z0 (presumably the passthru - confirm) masked to its low 32 bits before the predicated mov
+  ret <vscale x 2 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)


        


More information about the llvm-commits mailing list