[llvm] c00a44f - [VP] IR expansion pass for VP gather and scatter

Simon Moll via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 18 08:00:50 PDT 2022


Author: Lorenzo Albano
Date: 2022-07-18T17:00:38+02:00
New Revision: c00a44fa683983f2f0373f5f7b684f8b0213609f

URL: https://github.com/llvm/llvm-project/commit/c00a44fa683983f2f0373f5f7b684f8b0213609f
DIFF: https://github.com/llvm/llvm-project/commit/c00a44fa683983f2f0373f5f7b684f8b0213609f.diff

LOG: [VP] IR expansion pass for VP gather and scatter

Add expansion of the vp_gather and vp_scatter intrinsics to their non-VP masked counterparts (llvm.masked.gather and llvm.masked.scatter).

Reviewed By: simoll

Differential Revision: https://reviews.llvm.org/D120664

Added: 
    llvm/test/CodeGen/Generic/expand-vp-gather-scatter.ll

Modified: 
    llvm/lib/CodeGen/ExpandVectorPredication.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index ae913f543ce6..515484c583ed 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -402,6 +402,8 @@ CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                       VPIntrinsic &VPI) {
   assert(VPI.canIgnoreVectorLengthParam());
 
+  const auto &DL = F.getParent()->getDataLayout();
+
   Value *MaskParam = VPI.getMaskParam();
   Value *PtrParam = VPI.getMemoryPointerParam();
   Value *DataParam = VPI.getMemoryDataParam();
@@ -437,6 +439,22 @@ CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
           VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);
 
     break;
+  case Intrinsic::vp_scatter: {
+    auto *ElementType =
+        cast<VectorType>(DataParam->getType())->getElementType();
+    NewMemoryInst = Builder.CreateMaskedScatter(
+        DataParam, PtrParam,
+        AlignOpt.getValueOr(DL.getPrefTypeAlign(ElementType)), MaskParam);
+    break;
+  }
+  case Intrinsic::vp_gather: {
+    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
+    NewMemoryInst = Builder.CreateMaskedGather(
+        VPI.getType(), PtrParam,
+        AlignOpt.getValueOr(DL.getPrefTypeAlign(ElementType)), MaskParam,
+        nullptr, VPI.getName());
+    break;
+  }
   }
 
   assert(NewMemoryInst);
@@ -525,6 +543,8 @@ Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
     break;
   case Intrinsic::vp_load:
   case Intrinsic::vp_store:
+  case Intrinsic::vp_gather:
+  case Intrinsic::vp_scatter:
     return expandPredicationInMemoryIntrinsic(Builder, VPI);
   }
 

diff  --git a/llvm/test/CodeGen/Generic/expand-vp-gather-scatter.ll b/llvm/test/CodeGen/Generic/expand-vp-gather-scatter.ll
new file mode 100644
index 000000000000..8e257195bd82
--- /dev/null
+++ b/llvm/test/CodeGen/Generic/expand-vp-gather-scatter.ll
@@ -0,0 +1,118 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt --expandvp -S < %s | FileCheck %s
+
+; Fixed-length vectors: the EVL is folded into the mask with a step-vector compare.
+define <4 x i32> @vpgather_v4i32(<4 x i32*> %ptrs, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpgather_v4i32(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[V1:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[PTRS:%.*]], i32 4, <4 x i1> [[TMP2]], <4 x i32> undef)
+; CHECK-NEXT:    ret <4 x i32> [[V1]]
+;
+  %v = call <4 x i32> @llvm.vp.gather.v4i32.v4p0i32(<4 x i32*> %ptrs, <4 x i1> %m, i32 %evl)
+  ret <4 x i32> %v
+}
+
+define <2 x i64> @vpgather_v2i64(<2 x i64*> %ptrs, <2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpgather_v2i64(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <2 x i32> <i32 0, i32 1>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[V1:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[PTRS:%.*]], i32 8, <2 x i1> [[TMP2]], <2 x i64> undef)
+; CHECK-NEXT:    ret <2 x i64> [[V1]]
+;
+  %v = call <2 x i64> @llvm.vp.gather.v2i64.v2p0i64(<2 x i64*> %ptrs, <2 x i1> %m, i32 %evl)
+  ret <2 x i64> %v
+}
+
+define void @vpscatter_v4i32(<4 x i32> %val, <4 x i32*> %ptrs, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpscatter_v4i32(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[PTRS:%.*]], i32 4, <4 x i1> [[TMP2]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.scatter.v4i32.v4p0i32(<4 x i32> %val, <4 x i32*> %ptrs, <4 x i1> %m, i32 %evl)
+  ret void
+}
+
+define void @vpscatter_v2i64(<2 x i64> %val, <2 x i64*> %ptrs, <2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpscatter_v2i64(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <2 x i32> <i32 0, i32 1>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    call void @llvm.masked.scatter.v2i64.v2p0i64(<2 x i64> [[VAL:%.*]], <2 x i64*> [[PTRS:%.*]], i32 8, <2 x i1> [[TMP2]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.scatter.v2i64.v2p0i64(<2 x i64> %val, <2 x i64*> %ptrs, <2 x i1> %m, i32 %evl)
+  ret void
+}
+
+; Scalable vectors: the EVL is folded into the mask with llvm.get.active.lane.mask.
+define <vscale x 2 x i32> @vpgather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpgather_nxv2i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 2 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 2
+; CHECK-NEXT:    [[V1:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTRS:%.*]], i32 4, <vscale x 2 x i1> [[TMP2]], <vscale x 2 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[V1]]
+;
+  %v = call <vscale x 2 x i32> @llvm.vp.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %v
+}
+
+define <vscale x 1 x i64> @vpgather_nxv1i64(<vscale x 1 x i64*> %ptrs, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpgather_nxv1i64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 1 x i1> @llvm.get.active.lane.mask.nxv1i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 1 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 1
+; CHECK-NEXT:    [[V1:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0i64(<vscale x 1 x i64*> [[PTRS:%.*]], i32 8, <vscale x 1 x i1> [[TMP2]], <vscale x 1 x i64> undef)
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[V1]]
+;
+  %v = call <vscale x 1 x i64> @llvm.vp.gather.nxv1i64.nxv1p0i64(<vscale x 1 x i64*> %ptrs, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x i64> %v
+}
+
+define void @vpscatter_nxv2i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpscatter_nxv2i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 2 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> [[VAL:%.*]], <vscale x 2 x i32*> [[PTRS:%.*]], i32 4, <vscale x 2 x i1> [[TMP2]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m, i32 %evl)
+  ret void
+}
+
+define void @vpscatter_nxv1i64(<vscale x 1 x i64> %val, <vscale x 1 x i64*> %ptrs, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: @vpscatter_nxv1i64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 1 x i1> @llvm.get.active.lane.mask.nxv1i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 1 x i1> [[TMP1]], [[M:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 1
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv1i64.nxv1p0i64(<vscale x 1 x i64> [[VAL:%.*]], <vscale x 1 x i64*> [[PTRS:%.*]], i32 8, <vscale x 1 x i1> [[TMP2]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.scatter.nxv1i64.nxv1p0i64(<vscale x 1 x i64> %val, <vscale x 1 x i64*> %ptrs, <vscale x 1 x i1> %m, i32 %evl)
+  ret void
+}
+
+declare <4 x i32> @llvm.vp.gather.v4i32.v4p0i32(<4 x i32*>, <4 x i1>, i32)
+declare <2 x i64> @llvm.vp.gather.v2i64.v2p0i64(<2 x i64*>, <2 x i1>, i32)
+declare void @llvm.vp.scatter.v4i32.v4p0i32(<4 x i32>, <4 x i32*>, <4 x i1>, i32)
+declare void @llvm.vp.scatter.v2i64.v2p0i64(<2 x i64>, <2 x i64*>, <2 x i1>, i32)
+
+declare <vscale x 2 x i32> @llvm.vp.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*>, <vscale x 2 x i1>, i32)
+declare <vscale x 1 x i64> @llvm.vp.gather.nxv1i64.nxv1p0i64(<vscale x 1 x i64*>, <vscale x 1 x i1>, i32)
+declare void @llvm.vp.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, <vscale x 2 x i1>, i32)
+declare void @llvm.vp.scatter.nxv1i64.nxv1p0i64(<vscale x 1 x i64>, <vscale x 1 x i64*>, <vscale x 1 x i1>, i32)


        


More information about the llvm-commits mailing list