[llvm-branch-commits] [llvm] 603d40d - [SVE][CodeGen] Add a DAG combine to extend mscatter indices

Kerry McLaughlin via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 25 03:27:37 PST 2020


Author: Kerry McLaughlin
Date: 2020-11-25T11:18:22Z
New Revision: 603d40da9d532ab4706e32c07aba339e180ed865

URL: https://github.com/llvm/llvm-project/commit/603d40da9d532ab4706e32c07aba339e180ed865
DIFF: https://github.com/llvm/llvm-project/commit/603d40da9d532ab4706e32c07aba339e180ed865.diff

LOG: [SVE][CodeGen] Add a DAG combine to extend mscatter indices

This patch adds a target-specific DAG combine for mscatter to promote indices
with element types i8 or i16 before legalisation, plus various tests with illegal types.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D90945

Added: 
    llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b92eb1d0e4f6..e4c20cc4e6e3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -835,6 +835,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::MUL);
 
   setTargetDAGCombine(ISD::SELECT);
@@ -13944,6 +13946,44 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performMSCATTERCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      SelectionDAG &DAG) {
+  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
+  assert(MSC && "Can only combine scatter store nodes");
+  // Widen i8/i16 scatter index elements to i32; other scatters are left alone.
+  SDLoc DL(MSC);
+  SDValue Chain = MSC->getChain();
+  SDValue Scale = MSC->getScale();
+  SDValue Index = MSC->getIndex();
+  SDValue Data = MSC->getValue();
+  SDValue Mask = MSC->getMask();
+  SDValue BasePtr = MSC->getBasePtr();
+  ISD::MemIndexType IndexType = MSC->getIndexType();
+
+  EVT IdxVT = Index.getValueType();
+
+  if (DCI.isBeforeLegalize()) {
+    // SVE gather/scatter requires indices of i32/i64. Extend i8/i16 index
+    // elements to i32 before legalisation so it can be split if required.
+    if ((IdxVT.getVectorElementType() == MVT::i8) ||
+        (IdxVT.getVectorElementType() == MVT::i16)) {
+      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
+      if (MSC->isIndexSigned()) // Extension kind follows index signedness.
+        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
+      else
+        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
+
+      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                  MSC->getMemoryVT(), DL, Ops,
+                                  MSC->getMemOperand(), IndexType,
+                                  MSC->isTruncatingStore());
+    }
+  }
+
+  return SDValue(); // Not before legalisation, or index is not i8/i16.
+}
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
@@ -15136,6 +15176,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MSCATTER:
+    return performMSCATTERCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll
new file mode 100644
index 000000000000..c3746a61d875
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll
@@ -0,0 +1,59 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Tests that exercise various type legalisation scenarios for ISD::MSCATTER.
+
+; Generate code for the case where the offset vector type (i8 elements) is illegal.
+define void @masked_scatter_nxv16i8(<vscale x 16 x i8> %data, i8* %base, <vscale x 16 x i8> %offsets, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv16i8:
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %offsets
+  call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_nxv8i16(<vscale x 8 x i16> %data, i16* %base, <vscale x 8 x i16> %offsets, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv8i16:
+; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
+; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
+; CHECK: ret
+  %ptrs = getelementptr i16, i16* %base, <vscale x 8 x i16> %offsets
+  call void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16> %data, <vscale x 8 x i16*> %ptrs, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_nxv8f32(<vscale x 8 x float> %data, float* %base, <vscale x 8 x i32> %indexes, <vscale x 8 x i1> %masks) {
+; CHECK-LABEL: masked_scatter_nxv8f32:
+; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
+; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
+  %ext = zext <vscale x 8 x i32> %indexes to <vscale x 8 x i64>
+  %ptrs = getelementptr float, float* %base, <vscale x 8 x i64> %ext
+  call void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float> %data, <vscale x 8 x float*> %ptrs, i32 0, <vscale x 8 x i1> %masks)
+  ret void
+}
+
+; Generate code for the worst-case scenario, where all vector types are illegal.
+define void @masked_scatter_nxv32i32(<vscale x 32 x i32> %data, i32* %base, <vscale x 32 x i32> %offsets, <vscale x 32 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv32i32:
+; CHECK-NOT: unpkhi
+; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z2.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z3.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z4.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z5.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z6.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z7.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK: ret
+  %ptrs = getelementptr i32, i32* %base, <vscale x 32 x i32> %offsets
+  call void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32> %data, <vscale x 32 x i32*> %ptrs, i32 4, <vscale x 32 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8*>,  i32, <vscale x 16 x i1>)
+declare void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16*>,  i32, <vscale x 8 x i1>)
+declare void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float>, <vscale x 8 x float*>, i32, <vscale x 8 x i1>)
+declare void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32>, <vscale x 32 x i32*>,  i32, <vscale x 32 x i1>)


        


More information about the llvm-branch-commits mailing list