[llvm] b16e868 - [CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index

Craig Topper via llvm-commits <llvm-commits at lists.llvm.org>
Wed Sep 2 20:45:10 PDT 2020


Author: Craig Topper
Date: 2020-09-02T20:44:12-07:00
New Revision: b16e8687ab6c977ddab3409939e867828f394311

URL: https://github.com/llvm/llvm-project/commit/b16e8687ab6c977ddab3409939e867828f394311
DIFF: https://github.com/llvm/llvm-project/commit/b16e8687ab6c977ddab3409939e867828f394311.diff

LOG: [CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index

This helps SelectionDAGBuilder recognize the splat can be used as a uniform base.
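
For illustration, a minimal before/after IR sketch of the rewrite, mirroring the
splat_ptr_gather test updated below (%ptr, %mask and %passthru are that test's
arguments):

    ; Before: the pointer operand is a splat built from insertelement/shufflevector.
    %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
    %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
    %3 = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %2, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

    ; After: a GEP with the scalar base %ptr and an all-zeroes vector index, which
    ; SelectionDAGBuilder can recognize as a uniform base.
    %1 = getelementptr i32, i32* %ptr, <4 x i64> zeroinitializer
    %2 = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %1, i32 4, <4 x i1> %mask, <4 x i32> %passthru)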

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D86371

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/test/CodeGen/X86/masked_gather_scatter.ll
    llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 074960e7ced2..8498335bf78e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -358,7 +358,7 @@ int getSplatIndex(ArrayRef<int> Mask);
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Return true if each element of the vector value \p V is poisoned or equal to
 /// every other non-poisoned element. If an index element is specified, either

diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 0bc8b7281d91..e241300dd2e7 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -342,7 +342,7 @@ int llvm::getSplatIndex(ArrayRef<int> Mask) {
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
 /// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+Value *llvm::getSplatValue(const Value *V) {
   if (isa<VectorType>(V->getType()))
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue();

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 3272f36a1436..9a4ed2fab608 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -5314,88 +5314,112 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP || !GEP->hasIndices())
+  // FIXME: Support scalable vectors.
+  if (isa<ScalableVectorType>(Ptr->getType()))
     return false;
 
-  // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
-  // FIXME: We should support this by sinking the GEP.
-  if (MemoryInst->getParent() != GEP->getParent())
-    return false;
-
-  SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
+  Value *NewAddr;
 
-  bool RewriteGEP = false;
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    // Don't optimize GEPs that don't have indices.
+    if (!GEP->hasIndices())
+      return false;
 
-  if (Ops[0]->getType()->isVectorTy()) {
-    Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
-    if (!Ops[0])
+    // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+    // FIXME: We should support this by sinking the GEP.
+    if (MemoryInst->getParent() != GEP->getParent())
       return false;
-    RewriteGEP = true;
-  }
 
-  unsigned FinalIndex = Ops.size() - 1;
+    SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
 
-  // Ensure all but the last index is 0.
-  // FIXME: This isn't strictly required. All that's required is that they are
-  // all scalars or splats.
-  for (unsigned i = 1; i < FinalIndex; ++i) {
-    auto *C = dyn_cast<Constant>(Ops[i]);
-    if (!C)
-      return false;
-    if (isa<VectorType>(C->getType()))
-      C = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(C);
-    if (!CI || !CI->isZero())
-      return false;
-    // Scalarize the index if needed.
-    Ops[i] = CI;
-  }
-
-  // Try to scalarize the final index.
-  if (Ops[FinalIndex]->getType()->isVectorTy()) {
-    if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
-      auto *C = dyn_cast<ConstantInt>(V);
-      // Don't scalarize all zeros vector.
-      if (!C || !C->isZero()) {
-        Ops[FinalIndex] = V;
-        RewriteGEP = true;
-      }
+    bool RewriteGEP = false;
+
+    if (Ops[0]->getType()->isVectorTy()) {
+      Ops[0] = getSplatValue(Ops[0]);
+      if (!Ops[0])
+        return false;
+      RewriteGEP = true;
     }
-  }
 
-  // If we made any changes or the we have extra operands, we need to generate
-  // new instructions.
-  if (!RewriteGEP && Ops.size() == 2)
-    return false;
+    unsigned FinalIndex = Ops.size() - 1;
 
-  unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+    // Ensure all but the last index is 0.
+    // FIXME: This isn't strictly required. All that's required is that they are
+    // all scalars or splats.
+    for (unsigned i = 1; i < FinalIndex; ++i) {
+      auto *C = dyn_cast<Constant>(Ops[i]);
+      if (!C)
+        return false;
+      if (isa<VectorType>(C->getType()))
+        C = C->getSplatValue();
+      auto *CI = dyn_cast_or_null<ConstantInt>(C);
+      if (!CI || !CI->isZero())
+        return false;
+      // Scalarize the index if needed.
+      Ops[i] = CI;
+    }
+
+    // Try to scalarize the final index.
+    if (Ops[FinalIndex]->getType()->isVectorTy()) {
+      if (Value *V = getSplatValue(Ops[FinalIndex])) {
+        auto *C = dyn_cast<ConstantInt>(V);
+        // Don't scalarize all zeros vector.
+        if (!C || !C->isZero()) {
+          Ops[FinalIndex] = V;
+          RewriteGEP = true;
+        }
+      }
+    }
 
-  IRBuilder<> Builder(MemoryInst);
+    // If we made any changes or we have extra operands, we need to generate
+    // new instructions.
+    if (!RewriteGEP && Ops.size() == 2)
+      return false;
 
-  Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
 
-  Value *NewAddr;
+    IRBuilder<> Builder(MemoryInst);
 
-  // If the final index isn't a vector, emit a scalar GEP containing all ops
-  // and a vector GEP with all zeroes final index.
-  if (!Ops[FinalIndex]->getType()->isVectorTy()) {
-    NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
-    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
-    NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
-  } else {
-    Value *Base = Ops[0];
-    Value *Index = Ops[FinalIndex];
+    Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
 
-    // Create a scalar GEP if there are more than 2 operands.
-    if (Ops.size() != 2) {
-      // Replace the last index with 0.
-      Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
-      Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+    // If the final index isn't a vector, emit a scalar GEP containing all ops
+    // and a vector GEP with all zeroes final index.
+    if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+      NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
+      auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+      NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
+    } else {
+      Value *Base = Ops[0];
+      Value *Index = Ops[FinalIndex];
+
+      // Create a scalar GEP if there are more than 2 operands.
+      if (Ops.size() != 2) {
+        // Replace the last index with 0.
+        Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
+        Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+      }
+
+      // Now create the GEP with scalar pointer and vector index.
+      NewAddr = Builder.CreateGEP(Base, Index);
     }
+  } else if (!isa<Constant>(Ptr)) {
+    // Not a GEP, maybe it's a splat and we can create a GEP to enable
+    // SelectionDAGBuilder to use it as a uniform base.
+    Value *V = getSplatValue(Ptr);
+    if (!V)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
 
-    // Now create the GEP with scalar pointer and vector index.
-    NewAddr = Builder.CreateGEP(Base, Index);
+    // Emit a vector GEP with a scalar pointer and all 0s vector index.
+    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+  } else {
+    // Constant, SelectionDAGBuilder knows to check if it's a splat.
+    return false;
   }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);

diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index c5781e834075..88418fd85fe5 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -3323,14 +3323,13 @@ define void @scatter_16i64_constant_indices(i32* %ptr, <16 x i1> %mask, <16 x i3
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; KNL_64-LABEL: splat_ptr_gather:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpgatherqd (,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1}
 ; KNL_64-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
@@ -3342,8 +3341,9 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpgatherdd (,%zmm0), %zmm1 {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpgatherdd (%eax,%zmm0,4), %zmm1 {%k1}
 ; KNL_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
@@ -3352,18 +3352,18 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpgatherqd (,%ymm0), %xmm1 {%k1}
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
 ; SKX-NEXT:    vmovdqa %xmm1, %xmm0
-; SKX-NEXT:    vzeroupper
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_gather:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpgatherdd (,%xmm0), %xmm1 {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
 ; SKX_32-NEXT:    vmovdqa %xmm1, %xmm0
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
@@ -3376,14 +3376,13 @@ declare  <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>,
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_64-LABEL: splat_ptr_scatter:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpscatterqd %ymm1, (,%zmm0) {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -3394,8 +3393,9 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpscatterdd %zmm1, (,%zmm0) {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
@@ -3403,17 +3403,17 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpscatterqd %xmm1, (,%ymm0) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr_scatter:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpscatterdd %xmm1, (,%xmm0) {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer

diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index c1674ad4ca45..adb1930ca782 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,10 +87,9 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
+; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
@@ -100,9 +99,8 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
 
 define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0


        

