[llvm] ad43217 - [InstCombine] Fold for masked gather when loading the same value each time.

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 25 07:09:05 PST 2022


Caroline,

A generalization of this would be to check for a non-zero mask, and then 
do a load, broadcast, select sequence.  I'd encourage you to follow up 
with that.
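
For concreteness, here is a rough, untested sketch of that follow-up, written
against the same helpers the committed code uses (variable names are only
illustrative, and a real patch would also need to establish that the constant
mask has at least one lane known to be true, not merely that it is not
all-zero):

  if (!ConstMask->isZeroValue())
    if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
      auto *VecTy = cast<VectorType>(II.getType());
      const Align Alignment =
          cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
      // Some lane is active, so the gather already loads from SplatPtr and a
      // single scalar load of that address is safe to emit.
      LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
                                              Alignment, "load.scalar");
      // Broadcast the loaded value to every lane ...
      Value *Splat =
          Builder.CreateVectorSplat(VecTy->getElementCount(), L, "broadcast");
      // ... and keep the passthru value in the inactive lanes.
      Value *Sel =
          Builder.CreateSelect(ConstMask, Splat, II.getArgOperand(3), "sel");
      return replaceInstUsesWith(II, Sel);
    }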

Philip

On 1/21/22 6:20 AM, Caroline Concatto via llvm-commits wrote:
> Author: Caroline Concatto
> Date: 2022-01-21T14:19:51Z
> New Revision: ad43217a046634be24174299beec3a28018ec3c0
>
> URL: https://github.com/llvm/llvm-project/commit/ad43217a046634be24174299beec3a28018ec3c0
> DIFF: https://github.com/llvm/llvm-project/commit/ad43217a046634be24174299beec3a28018ec3c0.diff
>
> LOG: [InstCombine] Fold for masked gather when loading the same value each time.
>
> This patch handles the case where the masked gather's first operand (the
> vector of pointers) is a splat and the mask is all ones: every lane then
> reloads the same value. The gather is replaced by a scalar load of that
> value followed by a splat of it into a vector.
>
> Differential Revision: https://reviews.llvm.org/D115726
>
> Added:
>      
>
> Modified:
>      llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
>      llvm/test/Transforms/InstCombine/masked_intrinsics.ll
>
> Removed:
>      
>
>
> ################################################################################
> diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> index e3a9e806abdba..f63a186166ecc 100644
> --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> @@ -352,9 +352,27 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
>   // * Dereferenceable address & few lanes -> scalarize speculative load/selects
>   // * Adjacent vector addresses -> masked.load
>   // * Narrow width by halfs excluding zero/undef lanes
> -// * Vector splat address w/known mask -> scalar load
>   // * Vector incrementing address -> vector masked load
>   Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
> +  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
> +  if (!ConstMask)
> +    return nullptr;
> +
> +  // Vector splat address w/known mask -> scalar load
> +  // Fold the gather to a scalar load of the splatted address followed by a
> +  // broadcast, since every lane reloads the same value
> +  if (ConstMask->isAllOnesValue())
> +    if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
> +      auto *VecTy = cast<VectorType>(II.getType());
> +      const Align Alignment =
> +          cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
> +      LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
> +                                              Alignment, "load.scalar");
> +      Value *Shuf =
> +          Builder.CreateVectorSplat(VecTy->getElementCount(), L, "broadcast");
> +      return replaceInstUsesWith(II, cast<Instruction>(Shuf));
> +    }
> +
>     return nullptr;
>   }
>   
>
> diff  --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> index a82aebd738fea..5ba559fd35f9c 100644
> --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> @@ -376,3 +376,65 @@ define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_m
>   ; Function Attrs:
>   declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
>   declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
> +
> +; Test gathers that can be simplified to scalar load + splat
> +
> +;; Splat address and all active mask
> +define <vscale x 2 x i64> @gather_nxv2i64_uniform_ptrs_all_active_mask(i64* %src) {
> +; CHECK-LABEL: @gather_nxv2i64_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
> +; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
> +; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 2 x i64> [[BROADCAST_SPLATINSERT1]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
> +; CHECK-NEXT:    ret <vscale x 2 x i64> [[BROADCAST_SPLAT2]]
> +;
> +  %broadcast.splatinsert = insertelement <vscale x 2 x i64*> poison, i64 *%src, i32 0
> +  %broadcast.splat = shufflevector <vscale x 2 x i64*> %broadcast.splatinsert, <vscale x 2 x i64*> poison, <vscale x 2 x i32> zeroinitializer
> +  %res = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %broadcast.splat, i32 8, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i64> undef)
> +  ret <vscale x 2 x i64> %res
> +}
> +
> +define <2 x i64> @gather_v2i64_uniform_ptrs_all_active_mask(i64* %src) {
> +; CHECK-LABEL: @gather_v2i64_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
> +; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
> +; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT1]], <2 x i64> poison, <2 x i32> zeroinitializer
> +; CHECK-NEXT:    ret <2 x i64> [[BROADCAST_SPLAT2]]
> +;
> +  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
> +  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
> +  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> <i1 1, i1 1>, <2 x i64> undef)
> +  ret <2 x i64> %res
> +}
> +
> +; Negative gather tests
> +
> +;; Vector of pointers is not a splat.
> +define <2 x i64> @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(<2 x i64*> %inVal, i64* %src ) {
> +; CHECK-LABEL: @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT:    [[INSERT_VALUE:%.*]] = insertelement <2 x i64*> [[INVAL:%.*]], i64* [[SRC:%.*]], i64 1
> +; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[INSERT_VALUE]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> undef)
> +; CHECK-NEXT:    ret <2 x i64> [[RES]]
> +;
> +  %insert.value = insertelement <2 x i64*> %inVal, i64 *%src, i32 1
> +  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %insert.value, i32 8, <2 x i1><i1 1, i1 1>, <2 x i64> undef)
> +  ret <2 x i64> %res
> +}
> +
> +;; Unknown mask value
> +define <2 x i64> @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(i64* %src, <2 x i1> %mask) {
> +; CHECK-LABEL: @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(
> +; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64*> poison, i64* [[SRC:%.*]], i64 0
> +; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64*> [[BROADCAST_SPLATINSERT]], <2 x i64*> poison, <2 x i32> zeroinitializer
> +; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[BROADCAST_SPLAT]], i32 8, <2 x i1> [[MASK:%.*]], <2 x i64> undef)
> +; CHECK-NEXT:    ret <2 x i64> [[RES]]
> +;
> +  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
> +  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
> +  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> %mask, <2 x i64> undef)
> +  ret <2 x i64> %res
> +}
> +
> +; Function Attrs:
> +declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
> +declare <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*>, i32, <2 x i1>, <2 x i64>)
> +
>
>
>          
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits

