[llvm-branch-commits] [llvm] abe7775 - [SVE][CodeGen] Extend index of masked gathers

Kerry McLaughlin via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 10 05:59:34 PST 2020


Author: Kerry McLaughlin
Date: 2020-12-10T13:54:45Z
New Revision: abe7775f5a43e5a0d8ec237542274ba3e73937e4

URL: https://github.com/llvm/llvm-project/commit/abe7775f5a43e5a0d8ec237542274ba3e73937e4
DIFF: https://github.com/llvm/llvm-project/commit/abe7775f5a43e5a0d8ec237542274ba3e73937e4.diff

LOG: [SVE][CodeGen] Extend index of masked gathers

This patch changes performMSCATTERCombine (renamed here to
performMaskedGatherScatterCombine) to also promote the indices of masked
gathers to i32 where the index element type is i8 or i16, and adds various
tests for gathers with illegal types.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D91433

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5d9c66e170eab..01301abf10e3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -849,6 +849,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::MSCATTER);
 
   setTargetDAGCombine(ISD::MUL);
@@ -14063,20 +14064,19 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performMSCATTERCombine(SDNode *N,
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       SelectionDAG &DAG) {
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  assert(MSC && "Can only combine scatter store nodes");
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  SDLoc DL(MSC);
-  SDValue Chain = MSC->getChain();
-  SDValue Scale = MSC->getScale();
-  SDValue Index = MSC->getIndex();
-  SDValue Data = MSC->getValue();
-  SDValue Mask = MSC->getMask();
-  SDValue BasePtr = MSC->getBasePtr();
-  ISD::MemIndexType IndexType = MSC->getIndexType();
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
 
   EVT IdxVT = Index.getValueType();
 
@@ -14086,16 +14086,27 @@ static SDValue performMSCATTERCombine(SDNode *N,
     if ((IdxVT.getVectorElementType() == MVT::i8) ||
         (IdxVT.getVectorElementType() == MVT::i16)) {
       EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-      if (MSC->isIndexSigned())
+      if (MGS->isIndexSigned())
         Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
       else
         Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
 
-      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                  MSC->getMemoryVT(), DL, Ops,
-                                  MSC->getMemOperand(), IndexType,
-                                  MSC->isTruncatingStore());
+      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+        SDValue PassThru = MGT->getPassThru();
+        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                   PassThru.getValueType(), DL, Ops,
+                                   MGT->getMemOperand(),
+                                   MGT->getIndexType(), MGT->getExtensionType());
+      } else {
+        auto *MSC = cast<MaskedScatterSDNode>(MGS);
+        SDValue Data = MSC->getValue();
+        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                    MSC->getMemoryVT(), DL, Ops,
+                                    MSC->getMemOperand(), IndexType,
+                                    MSC->isTruncatingStore());
+      }
     }
   }
 
@@ -15072,9 +15083,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue
 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -15109,6 +15117,9 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   if (!EnableCombineMGatherIntrinsics)
     return SDValue();
 
@@ -15296,8 +15307,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMSCATTERCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
index 076edc1fd86da..4482730a7d74c 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -54,6 +54,19 @@ define <vscale x 2 x i32> @masked_gather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vsc
   ret <vscale x 2 x i32> %data
 }
 
+; Code generate the worst case scenario when all vector types are legal.
+define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv16i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices
+  %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %data
+}
+
 ; Code generate the worst case scenario when all vector types are illegal.
 define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv32i32:


        


More information about the llvm-branch-commits mailing list