[llvm] 8e5a5b6 - [InstCombine] Fold for masked scatters to a uniform address

Caroline Concatto via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 14 01:56:33 PST 2022


Author: Caroline Concatto
Date: 2022-01-14T09:44:34Z
New Revision: 8e5a5b619d34c7846e35d219c0747b9c29654f15

URL: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15
DIFF: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15.diff

LOG: [InstCombine] Fold for masked scatters to a uniform address

When a masked scatter intrinsic scatters a source vector to a single
(splat) destination address and the mask is all active, the store is
uniform. This patch replaces such a masked scatter with an extract of
the last lane of the source vector followed by a scalar store to the
destination pointer; the last lane is the right one to keep because
scatter lanes are written in order, so the highest lane wins on a
common address. This patch also folds the case where the scattered
value itself is a splat: there the mask only needs to be non-zero
(an all-zero scatter is erased earlier), and the scatter folds to a
scalar store of the splat value to the destination pointer.

Differential Revision: https://reviews.llvm.org/D115724
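
For illustration, a minimal sketch of the two folds on fixed-width IR
(%val, %vec, %splat.val, %splat.ptr and %dst are placeholder names,
not taken from the patch; %splat.val and %splat.ptr are splats of %val
and %dst):

  ; splat value, splat pointer, mask known non-zero -> one scalar store
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %splat.val, <4 x i16*> %splat.ptr, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ; becomes: store i16 %val, i16* %dst, align 2

  ; non-splat value, splat pointer, all-active mask -> store the last
  ; lane: scatter lanes are committed in order, so on a common address
  ; the highest lane's value is the one that survives
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %vec, <4 x i16*> %splat.ptr, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
  ; becomes: %last = extractelement <4 x i16> %vec, i64 3
  ;          store i16 %last, i16* %dst, align 2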

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/masked_intrinsics.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 22c736f423f07..37cc21aeff8b8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar store
 // * Vector incrementing address -> vector masked store
 Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
@@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   if (ConstMask->isNullValue())
     return eraseInstFromFunction(II);
 
+  // Vector splat address -> scalar store
+  if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
+    // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
+    if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
+      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      StoreInst *S =
+          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
+      S->copyMetadata(II);
+      return S;
+    }
+    // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
+    // lastlane), ptr
+    if (ConstMask->isAllOnesValue()) {
+      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
+      ElementCount VF = WideLoadTy->getElementCount();
+      Constant *EC =
+          ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
+      Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
+      Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
+      Value *Extract =
+          Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
+      StoreInst *S =
+          new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
+      S->copyMetadata(II);
+      return S;
+    }
+  }
   if (isa<ScalableVectorType>(ConstMask->getType()))
     return nullptr;
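
To make the last-lane index concrete for a scalable type: with a
<vscale x 4 x i16> source the known minimum element count is 4, so
RunTimeVF is vscale * 4 and the extracted index is vscale * 4 - 1.
A sketch of the resulting IR (%vec and %dst are placeholder names; the
multiply by 4 is canonicalized to a shift, matching the CHECK lines in
the scalable test below):

  %vscale = call i32 @llvm.vscale.i32()
  %runtime.vf = shl i32 %vscale, 2       ; vscale * 4
  %last.lane = add i32 %runtime.vf, -1   ; index of the final lane
  %elt = extractelement <vscale x 4 x i16> %vec, i32 %last.lane
  store i16 %elt, i16* %dst, align 2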
 

diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 884a042b58222..a82aebd738fea 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val)  {
   call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
   ret void
 }
+
+
+; Test scatters that can be simplified to scalar stores.
+
+;; Value splat (mask is not used)
+define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
+; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
+  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
+  ret void
+}
+
+define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
+; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
+  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+;; The pointer is splat and mask is all active, but value is not a splat
+define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>*  %src)  {
+; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
+; CHECK-NEXT:    store i16 [[TMP0]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
+  ret void
+}
+
+define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
+; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[TMP0]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
+; CHECK-NEXT:    store i16 [[TMP3]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+; Negative scatter tests
+
+;; Pointer is splat, but mask is not all active and value is not a splat
+define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
+; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
+; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
+; CHECK-NEXT:    ret void
+;
+  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
+  ret void
+}
+
+;; The pointer is NOT a splat
+define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
+; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    ret void
+;
+  %broadcast = shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
+  ret void
+}
+
+
+; Function Attrs:
+declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
+declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
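
To try the fold locally, the updated test can be run through
InstCombine directly, e.g. (a sketch; the file's own RUN line is
authoritative):

  opt -S -passes=instcombine llvm/test/Transforms/InstCombine/masked_intrinsics.ll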




More information about the llvm-commits mailing list