[llvm] 8e5a5b6 - [InstCombine] Fold for masked scatters to a uniform address

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 14 07:51:44 PST 2022


As a possible extension, you can perform the scalar store for any 
constant mask by using the last active lane in the mask.  That would 
generalize your code here.
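
Roughly like the following untested sketch (fixed-width masks only; it 
assumes the SplatPtr/Alignment/Builder names from your patch and would 
subsume the fixed-width side of your all-ones case).  Like the 
committed fold, it relies on the lane with the highest index winning 
when scatter addresses overlap:

    if (auto *MaskTy = dyn_cast<FixedVectorType>(ConstMask->getType())) {
      // Find the highest lane whose mask element is a known one; bail
      // out on any lane that is not a constant zero or one (e.g. undef).
      int LastActive = -1;
      for (unsigned I = 0, E = MaskTy->getNumElements(); I != E; ++I) {
        auto *Lane =
            dyn_cast_or_null<ConstantInt>(ConstMask->getAggregateElement(I));
        if (!Lane) {
          LastActive = -1; // unknown/undef lane; stay conservative
          break;
        }
        if (Lane->isOne())
          LastActive = I;
      }
      if (LastActive >= 0) {
        // Lanes scatter in increasing order, so with a splat address
        // only the last active lane's value survives.
        Value *Extract = Builder.CreateExtractElement(
            II.getArgOperand(0), Builder.getInt32(LastActive));
        StoreInst *S =
            new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
        S->copyMetadata(II);
        return S;
      }
    }

e.g. a scatter of <4 x i16> to a splat pointer with mask 
<i1 0, i1 1, i1 1, i1 0> would become a scalar store of lane 2.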

Philip

On 1/14/22 1:56 AM, Caroline Concatto via llvm-commits wrote:
> Author: Caroline Concatto
> Date: 2022-01-14T09:44:34Z
> New Revision: 8e5a5b619d34c7846e35d219c0747b9c29654f15
>
> URL: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15
> DIFF: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15.diff
>
> LOG: [InstCombine] Fold for masked scatters to a uniform address
>
> When a masked scatter intrinsic stores a source vector to a uniform
> (splat) destination address and the mask is all ones, this patch
> replaces the masked scatter with a scalar store of the element
> extracted from the last lane of the source vector to the destination
> pointer.
> This patch also folds the case where the value being scattered is a
> splat. Here the mask only needs to be non-zero (an all-zero mask has
> already been folded away), and the scatter becomes a scalar store of
> the splatted value to the destination pointer.
>
> Differential Revision: https://reviews.llvm.org/D115724
>
> Added:
>      
>
> Modified:
>      llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
>      llvm/test/Transforms/InstCombine/masked_intrinsics.ll
>
> Removed:
>      
>
>
> ################################################################################
> diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> index 22c736f423f07..37cc21aeff8b8 100644
> --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> @@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
>   // * Single constant active lane -> store
>   // * Adjacent vector addresses -> masked.store
>   // * Narrow store width by halfs excluding zero/undef lanes
> -// * Vector splat address w/known mask -> scalar store
>   // * Vector incrementing address -> vector masked store
>   Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
>     auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
> @@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
>     if (ConstMask->isNullValue())
>       return eraseInstFromFunction(II);
>   
> +  // Vector splat address -> scalar store
> +  if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
> +    // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
> +    if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
> +      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
> +      StoreInst *S =
> +          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
> +      S->copyMetadata(II);
> +      return S;
> +    }
> +    // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
> +    // lastlane), ptr
> +    if (ConstMask->isAllOnesValue()) {
> +      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
> +      VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
> +      ElementCount VF = WideLoadTy->getElementCount();
> +      Constant *EC =
> +          ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
> +      Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
> +      Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
> +      Value *Extract =
> +          Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
> +      StoreInst *S =
> +          new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
> +      S->copyMetadata(II);
> +      return S;
> +    }
> +  }
>     if (isa<ScalableVectorType>(ConstMask->getType()))
>       return nullptr;
>   
>
> diff  --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> index 884a042b58222..a82aebd738fea 100644
> --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> @@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val)  {
>     call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
>     ret void
>   }
> +
> +
> +; Test scatters that can be simplified to scalar stores.
> +
> +;; Value splat (mask is not used)
> +define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
> +; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
> +  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
> +  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
> +  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
> +  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
> +  ret void
> +}
> +
> +define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
> +; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
> +  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
> +  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
> +  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
> +  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
> +  ret void
> +}
> +
> +;; The pointer is splat and mask is all active, but value is not a splat
> +define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>*  %src)  {
> +; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
> +; CHECK-NEXT:    store i16 [[TMP0]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
> +  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
> +  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> +  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
> +  ret void
> +}
> +
> +define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
> +; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
> +; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[TMP0]], 2
> +; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], -1
> +; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
> +; CHECK-NEXT:    store i16 [[TMP3]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
> +  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
> +  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
> +  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
> +  ret void
> +}
> +
> +; Negative scatter tests
> +
> +;; Pointer is splat, but mask is not all active and value is not a splat
> +define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
> +; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
> +; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
> +; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
> +; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
> +; CHECK-NEXT:    ret void
> +;
> +  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
> +  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
> +  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> +  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
> +  ret void
> +}
> +
> +;; The pointer is NOT a splat
> +define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
> +; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:    [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
> +; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
> +; CHECK-NEXT:    ret void
> +;
> +  %broadcast= shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
> +  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> +  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1> )
> +  ret void
> +}
> +
> +
> +; Function Attrs:
> +declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
> +declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits

