[llvm] 8e5a5b6 - [InstCombine] Fold for masked scatters to a uniform address
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 14 07:51:44 PST 2022
As a possible extension, you could perform the scalar store for any
constant mask by using the last active lane in the mask. That would
generalize your code here.
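
For concreteness, here is a rough sketch of that generalization
(illustrative only, not part of this patch; it assumes a fixed-width
vector and reuses SplatPtr from the committed code below):

    // Sketch: handle any constant mask by storing the last active lane.
    if (auto *FVTy = dyn_cast<FixedVectorType>(ConstMask->getType())) {
      int LastActive = -1;
      for (unsigned I = 0, E = FVTy->getNumElements(); I != E; ++I)
        if (auto *Elt = ConstMask->getAggregateElement(I))
          if (Elt->isOneValue())
            LastActive = I; // remember the last enabled lane
      if (LastActive >= 0) {
        Align Alignment =
            cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
        Value *Extract = Builder.CreateExtractElement(
            II.getArgOperand(0), Builder.getInt32(LastActive));
        StoreInst *S = new StoreInst(Extract, SplatPtr,
                                     /*IsVolatile=*/false, Alignment);
        S->copyMetadata(II);
        return S;
      }
    }
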
Philip
On 1/14/22 1:56 AM, Caroline Concatto via llvm-commits wrote:
> Author: Caroline Concatto
> Date: 2022-01-14T09:44:34Z
> New Revision: 8e5a5b619d34c7846e35d219c0747b9c29654f15
>
> URL: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15
> DIFF: https://github.com/llvm/llvm-project/commit/8e5a5b619d34c7846e35d219c0747b9c29654f15.diff
>
> LOG: [InstCombine] Fold for masked scatters to a uniform address
>
> When a masked scatter stores a source vector to a uniform (splat)
> destination address, all active lanes write to the same location. When
> the mask is all ones, only the last lane's write survives, so this
> patch replaces the scatter with an extract of the last lane of the
> source vector followed by a scalar store to the destination address.
> The patch also folds the case where the stored value is itself a splat:
> provided the mask is not all zeroes, the scatter folds to a scalar
> store of the splat value through the destination pointer.
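>
> For example (an illustrative reduction of the first test below, where
> %splatvalue is a splat of %val and %splatptr is a splat of %dst):
>
>   call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %splatvalue, <4 x i16*> %splatptr, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
>
> becomes
>
>   store i16 %val, i16* %dst, align 2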
>
> Differential Revision: https://reviews.llvm.org/D115724
>
> Added:
>
>
> Modified:
> llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> llvm/test/Transforms/InstCombine/masked_intrinsics.ll
>
> Removed:
>
>
>
> ################################################################################
> diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> index 22c736f423f07..37cc21aeff8b8 100644
> --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
> @@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
> // * Single constant active lane -> store
> // * Adjacent vector addresses -> masked.store
> // * Narrow store width by halves excluding zero/undef lanes
> -// * Vector splat address w/known mask -> scalar store
> // * Vector incrementing address -> vector masked store
> Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
> auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
> @@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
> if (ConstMask->isNullValue())
> return eraseInstFromFunction(II);
>
> + // Vector splat address -> scalar store
> + if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
> + // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
> + if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
> + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
> + StoreInst *S =
> + new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
> + S->copyMetadata(II);
> + return S;
> + }
> + // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
> + // lastlane), ptr
> + if (ConstMask->isAllOnesValue()) {
> + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
> + VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
> + ElementCount VF = WideLoadTy->getElementCount();
> + Constant *EC =
> + ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
> + Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
> + Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
> + Value *Extract =
> + Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
> + StoreInst *S =
> + new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
> + S->copyMetadata(II);
> + return S;
> + }
> + }
> if (isa<ScalableVectorType>(ConstMask->getType()))
> return nullptr;
>
>
> diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> index 884a042b58222..a82aebd738fea 100644
> --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
> @@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val) {
> call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
> ret void
> }
> +
> +
> +; Test scatters that can be simplified to scalar stores.
> +
> +;; Value splat (mask is not used)
> +define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
> +; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
> +; CHECK-NEXT: entry:
> +; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT: ret void
> +;
> +entry:
> + %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
> + %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
> + %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
> + %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
> + call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
> + ret void
> +}
> +
> +define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
> +; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT: entry:
> +; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT: ret void
> +;
> +entry:
> + %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
> + %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
> + %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
> + %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
> + call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
> + ret void
> +}
> +
> +;; The pointer is splat and mask is all active, but value is not a splat
> +define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
> +; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT: entry:
> +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
> +; CHECK-NEXT: store i16 [[TMP0]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT: ret void
> +;
> +entry:
> + %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
> + %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
> + %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> + call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
> + ret void
> +}
> +
> +define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
> +; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT: entry:
> +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
> +; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[TMP0]], 2
> +; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], -1
> +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
> +; CHECK-NEXT: store i16 [[TMP3]], i16* [[DST:%.*]], align 2
> +; CHECK-NEXT: ret void
> +;
> +entry:
> + %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
> + %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
> + %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
> + call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
> + ret void
> +}
> +
> +; Negative scatter tests
> +
> +;; Pointer is splat, but mask is not all active and value is not a splat
> +define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
> +; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
> +; CHECK-NEXT: [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
> +; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
> +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
> +; CHECK-NEXT: ret void
> +;
> + %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
> + %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
> + %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> + call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
> + ret void
> +}
> +
> +;; The pointer is NOT a splat
> +define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
> +; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
> +; CHECK-NEXT: [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
> +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
> +; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
> +; CHECK-NEXT: ret void
> +;
> + %broadcast = shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
> + %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
> + call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
> + ret void
> +}
> +
> +
> +; Function Attrs:
> +declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
> +declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)