[llvm] ad43217 - [InstCombine] Fold for masked gather when loading the same value each time.

Caroline Concatto via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 21 06:20:23 PST 2022


Author: Caroline Concatto
Date: 2022-01-21T14:19:51Z
New Revision: ad43217a046634be24174299beec3a28018ec3c0

URL: https://github.com/llvm/llvm-project/commit/ad43217a046634be24174299beec3a28018ec3c0
DIFF: https://github.com/llvm/llvm-project/commit/ad43217a046634be24174299beec3a28018ec3c0.diff

LOG: [InstCombine] Fold for masked gather when loading the same value each time.

This patch adds a fold for the masked gather when the first operand (the
vector of pointers) is a splat and the mask is all ones, because such a
gather reloads the same value for every lane. The gather is replaced by a
scalar load of the value, which is then splatted into a vector.
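
In IR terms (an illustrative sketch, not taken from the patch; the types
and value names here are made up), the fold rewrites

    %gather = call <2 x i64> @llvm.masked.gather.v2i64(
                  <2 x i64*> %splat.of.ptr, i32 8,
                  <2 x i1> <i1 true, i1 true>, <2 x i64> undef)

into

    %load.scalar = load i64, i64* %ptr, align 8
    %ins = insertelement <2 x i64> poison, i64 %load.scalar, i64 0
    %broadcast = shufflevector <2 x i64> %ins, <2 x i64> poison,
                               <2 x i32> zeroinitializer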

Differential Revision: https://reviews.llvm.org/D115726

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/masked_intrinsics.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e3a9e806abdba..f63a186166ecc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -352,9 +352,27 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
 // * Dereferenceable address & few lanes -> scalarize speculative load/selects
 // * Adjacent vector addresses -> masked.load
 // * Narrow width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar load
 // * Vector incrementing address -> vector masked load
 Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
+  if (!ConstMask)
+    return nullptr;
+
+  // Vector splat address w/known mask -> scalar load
+  // Fold the gather to a scalar load of the splatted pointer, since every
+  // lane reloads the same value.
+  if (ConstMask->isAllOnesValue())
+    if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
+      auto *VecTy = cast<VectorType>(II.getType());
+      const Align Alignment =
+          cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+      LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
+                                              Alignment, "load.scalar");
+      Value *Shuf =
+          Builder.CreateVectorSplat(VecTy->getElementCount(), L, "broadcast");
+      return replaceInstUsesWith(II, cast<Instruction>(Shuf));
+    }
+
   return nullptr;
 }
 

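(A note on the guard: getSplatValue, declared in
llvm/include/llvm/Analysis/VectorUtils.h, recognizes among other forms the
canonical insertelement + shufflevector broadcast used by the tests below,
e.g. this illustrative sketch:

    %ins = insertelement <2 x i64*> poison, i64* %src, i64 0
    %splat = shufflevector <2 x i64*> %ins, <2 x i64*> poison,
                           <2 x i32> zeroinitializer

so the fold can fire whether the pointer splat is a constant or is built
at runtime.)
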
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index a82aebd738fea..5ba559fd35f9c 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -376,3 +376,65 @@ define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_m
 ; Function Attrs:
 declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
 declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
+
+; Test gathers that can be simplified to scalar load + splat
+
+;; Splat address and all active mask
+define <vscale x 2 x i64> @gather_nxv2i64_uniform_ptrs_all_active_mask(i64* %src) {
+; CHECK-LABEL: @gather_nxv2i64_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 2 x i64> [[BROADCAST_SPLATINSERT1]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[BROADCAST_SPLAT2]]
+;
+  %broadcast.splatinsert = insertelement <vscale x 2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <vscale x 2 x i64*> %broadcast.splatinsert, <vscale x 2 x i64*> poison, <vscale x 2 x i32> zeroinitializer
+  %res = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %broadcast.splat, i32 8, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %res
+}
+
+define <2 x i64> @gather_v2i64_uniform_ptrs_all_active_mask(i64* %src) {
+; CHECK-LABEL: @gather_v2i64_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <2 x i64> [[BROADCAST_SPLAT2]]
+;
+  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> <i1 1, i1 1>, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+; Negative gather tests
+
+;; Vector of pointers is not a splat.
+define <2 x i64> @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(<2 x i64*> %inVal, i64* %src ) {
+; CHECK-LABEL: @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[INSERT_VALUE:%.*]] = insertelement <2 x i64*> [[INVAL:%.*]], i64* [[SRC:%.*]], i64 1
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[INSERT_VALUE]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> undef)
+; CHECK-NEXT:    ret <2 x i64> [[RES]]
+;
+  %insert.value = insertelement <2 x i64*> %inVal, i64 *%src, i32 1
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %insert.value, i32 8, <2 x i1><i1 1, i1 1>, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+;; Unknown mask value
+define <2 x i64> @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(i64* %src, <2 x i1> %mask) {
+; CHECK-LABEL: @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64*> poison, i64* [[SRC:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64*> [[BROADCAST_SPLATINSERT]], <2 x i64*> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[BROADCAST_SPLAT]], i32 8, <2 x i1> [[MASK:%.*]], <2 x i64> undef)
+; CHECK-NEXT:    ret <2 x i64> [[RES]]
+;
+  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> %mask, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+; Function Attrs:
+declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*>, i32, <2 x i1>, <2 x i64>)
+
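
To see the fold in action, the updated test can be run with something like
the following (assuming an opt built at this revision; the authoritative
RUN line is at the top of the test file):

    opt -passes=instcombine -S \
        llvm/test/Transforms/InstCombine/masked_intrinsics.ll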