[llvm] b16e868 - [CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 2 20:45:10 PDT 2020
Author: Craig Topper
Date: 2020-09-02T20:44:12-07:00
New Revision: b16e8687ab6c977ddab3409939e867828f394311
URL: https://github.com/llvm/llvm-project/commit/b16e8687ab6c977ddab3409939e867828f394311
DIFF: https://github.com/llvm/llvm-project/commit/b16e8687ab6c977ddab3409939e867828f394311.diff
LOG: [CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index
This helps SelectionDAGBuilder recognize the splat can be used as a uniform base.
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D86371
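As exercised by the updated CodeGenPrepare/X86/gather-scatter-opt.ll test below, a gather whose pointer operand is a splat of a scalar pointer, e.g.

  %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
  %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
  %3 = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %2, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

is now rewritten by CodeGenPrepare into a GEP with the scalar pointer as base and an all-zeroes vector index:

  %1 = getelementptr i32, i32* %ptr, <4 x i64> zeroinitializer
  %2 = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %1, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

SelectionDAGBuilder can then use %ptr as the uniform base, which lets the X86 backend select e.g. vpgatherdd with a scalar base register instead of first broadcasting the pointer into a vector (see the masked_gather_scatter.ll changes below).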
Added:
Modified:
llvm/include/llvm/Analysis/VectorUtils.h
llvm/lib/Analysis/VectorUtils.cpp
llvm/lib/CodeGen/CodeGenPrepare.cpp
llvm/test/CodeGen/X86/masked_gather_scatter.ll
llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 074960e7ced2..8498335bf78e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -358,7 +358,7 @@ int getSplatIndex(ArrayRef<int> Mask);
/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 0bc8b7281d91..e241300dd2e7 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -342,7 +342,7 @@ int llvm::getSplatIndex(ArrayRef<int> Mask) {
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
/// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+Value *llvm::getSplatValue(const Value *V) {
if (isa<VectorType>(V->getType()))
if (auto *C = dyn_cast<Constant>(V))
return C->getSplatValue();
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 3272f36a1436..9a4ed2fab608 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -5314,88 +5314,112 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
/// zero index.
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
Value *Ptr) {
- const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
- if (!GEP || !GEP->hasIndices())
+ // FIXME: Support scalable vectors.
+ if (isa<ScalableVectorType>(Ptr->getType()))
return false;
- // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
- // FIXME: We should support this by sinking the GEP.
- if (MemoryInst->getParent() != GEP->getParent())
- return false;
-
- SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
+ Value *NewAddr;
- bool RewriteGEP = false;
+ if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+ // Don't optimize GEPs that don't have indices.
+ if (!GEP->hasIndices())
+ return false;
- if (Ops[0]->getType()->isVectorTy()) {
- Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
- if (!Ops[0])
+ // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+ // FIXME: We should support this by sinking the GEP.
+ if (MemoryInst->getParent() != GEP->getParent())
return false;
- RewriteGEP = true;
- }
- unsigned FinalIndex = Ops.size() - 1;
+ SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
- // Ensure all but the last index is 0.
- // FIXME: This isn't strictly required. All that's required is that they are
- // all scalars or splats.
- for (unsigned i = 1; i < FinalIndex; ++i) {
- auto *C = dyn_cast<Constant>(Ops[i]);
- if (!C)
- return false;
- if (isa<VectorType>(C->getType()))
- C = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(C);
- if (!CI || !CI->isZero())
- return false;
- // Scalarize the index if needed.
- Ops[i] = CI;
- }
-
- // Try to scalarize the final index.
- if (Ops[FinalIndex]->getType()->isVectorTy()) {
- if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
- auto *C = dyn_cast<ConstantInt>(V);
- // Don't scalarize all zeros vector.
- if (!C || !C->isZero()) {
- Ops[FinalIndex] = V;
- RewriteGEP = true;
- }
+ bool RewriteGEP = false;
+
+ if (Ops[0]->getType()->isVectorTy()) {
+ Ops[0] = getSplatValue(Ops[0]);
+ if (!Ops[0])
+ return false;
+ RewriteGEP = true;
}
- }
- // If we made any changes or we have extra operands, we need to generate
- // new instructions.
- if (!RewriteGEP && Ops.size() == 2)
- return false;
+ unsigned FinalIndex = Ops.size() - 1;
- unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+ // Ensure all but the last index is 0.
+ // FIXME: This isn't strictly required. All that's required is that they are
+ // all scalars or splats.
+ for (unsigned i = 1; i < FinalIndex; ++i) {
+ auto *C = dyn_cast<Constant>(Ops[i]);
+ if (!C)
+ return false;
+ if (isa<VectorType>(C->getType()))
+ C = C->getSplatValue();
+ auto *CI = dyn_cast_or_null<ConstantInt>(C);
+ if (!CI || !CI->isZero())
+ return false;
+ // Scalarize the index if needed.
+ Ops[i] = CI;
+ }
+
+ // Try to scalarize the final index.
+ if (Ops[FinalIndex]->getType()->isVectorTy()) {
+ if (Value *V = getSplatValue(Ops[FinalIndex])) {
+ auto *C = dyn_cast<ConstantInt>(V);
+ // Don't scalarize all zeros vector.
+ if (!C || !C->isZero()) {
+ Ops[FinalIndex] = V;
+ RewriteGEP = true;
+ }
+ }
+ }
- IRBuilder<> Builder(MemoryInst);
+ // If we made any changes or we have extra operands, we need to generate
+ // new instructions.
+ if (!RewriteGEP && Ops.size() == 2)
+ return false;
- Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+ unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
- Value *NewAddr;
+ IRBuilder<> Builder(MemoryInst);
- // If the final index isn't a vector, emit a scalar GEP containing all ops
- // and a vector GEP with all zeroes final index.
- if (!Ops[FinalIndex]->getType()->isVectorTy()) {
- NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
- auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
- NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
- } else {
- Value *Base = Ops[0];
- Value *Index = Ops[FinalIndex];
+ Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
- // Create a scalar GEP if there are more than 2 operands.
- if (Ops.size() != 2) {
- // Replace the last index with 0.
- Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
- Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+ // If the final index isn't a vector, emit a scalar GEP containing all ops
+ // and a vector GEP with all zeroes final index.
+ if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+ NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
+ auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+ NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
+ } else {
+ Value *Base = Ops[0];
+ Value *Index = Ops[FinalIndex];
+
+ // Create a scalar GEP if there are more than 2 operands.
+ if (Ops.size() != 2) {
+ // Replace the last index with 0.
+ Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
+ Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+ }
+
+ // Now create the GEP with scalar pointer and vector index.
+ NewAddr = Builder.CreateGEP(Base, Index);
}
+ } else if (!isa<Constant>(Ptr)) {
+ // Not a GEP, but maybe it's a splat and we can create a GEP to enable
+ // SelectionDAGBuilder to use it as a uniform base.
+ Value *V = getSplatValue(Ptr);
+ if (!V)
+ return false;
+
+ unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+ IRBuilder<> Builder(MemoryInst);
- // Now create the GEP with scalar pointer and vector index.
- NewAddr = Builder.CreateGEP(Base, Index);
+ // Emit a vector GEP with a scalar pointer and all 0s vector index.
+ Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+ auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+ NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+ } else {
+ // Constant, SelectionDAGBuilder knows to check if it's a splat.
+ return false;
}
MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index c5781e834075..88418fd85fe5 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -3323,14 +3323,13 @@ define void @scatter_16i64_constant_indices(i32* %ptr, <16 x i1> %mask, <16 x i3
define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; KNL_64-LABEL: splat_ptr_gather:
; KNL_64: # %bb.0:
-; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; KNL_64-NEXT: vpslld $31, %xmm0, %xmm0
; KNL_64-NEXT: vptestmd %zmm0, %zmm0, %k0
; KNL_64-NEXT: kshiftlw $12, %k0, %k0
; KNL_64-NEXT: kshiftrw $12, %k0, %k1
-; KNL_64-NEXT: vmovq %rdi, %xmm0
-; KNL_64-NEXT: vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT: vpgatherqd (,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT: vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1}
; KNL_64-NEXT: vmovdqa %xmm1, %xmm0
; KNL_64-NEXT: vzeroupper
; KNL_64-NEXT: retq
@@ -3342,8 +3341,9 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
; KNL_32-NEXT: vptestmd %zmm0, %zmm0, %k0
; KNL_32-NEXT: kshiftlw $12, %k0, %k0
; KNL_32-NEXT: kshiftrw $12, %k0, %k1
-; KNL_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT: vpgatherdd (,%zmm0), %zmm1 {%k1}
+; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT: vpgatherdd (%eax,%zmm0,4), %zmm1 {%k1}
; KNL_32-NEXT: vmovdqa %xmm1, %xmm0
; KNL_32-NEXT: vzeroupper
; KNL_32-NEXT: retl
@@ -3352,18 +3352,18 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
; SKX: # %bb.0:
; SKX-NEXT: vpslld $31, %xmm0, %xmm0
; SKX-NEXT: vpmovd2m %xmm0, %k1
-; SKX-NEXT: vpbroadcastq %rdi, %ymm0
-; SKX-NEXT: vpgatherqd (,%ymm0), %xmm1 {%k1}
+; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT: vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
; SKX-NEXT: vmovdqa %xmm1, %xmm0
-; SKX-NEXT: vzeroupper
; SKX-NEXT: retq
;
; SKX_32-LABEL: splat_ptr_gather:
; SKX_32: # %bb.0:
; SKX_32-NEXT: vpslld $31, %xmm0, %xmm0
; SKX_32-NEXT: vpmovd2m %xmm0, %k1
-; SKX_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT: vpgatherdd (,%xmm0), %xmm1 {%k1}
+; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT: vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
; SKX_32-NEXT: vmovdqa %xmm1, %xmm0
; SKX_32-NEXT: retl
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
@@ -3376,14 +3376,13 @@ declare <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>,
define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
; KNL_64-LABEL: splat_ptr_scatter:
; KNL_64: # %bb.0:
-; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; KNL_64-NEXT: vpslld $31, %xmm0, %xmm0
; KNL_64-NEXT: vptestmd %zmm0, %zmm0, %k0
; KNL_64-NEXT: kshiftlw $12, %k0, %k0
; KNL_64-NEXT: kshiftrw $12, %k0, %k1
-; KNL_64-NEXT: vmovq %rdi, %xmm0
-; KNL_64-NEXT: vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT: vpscatterqd %ymm1, (,%zmm0) {%k1}
+; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT: vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
; KNL_64-NEXT: vzeroupper
; KNL_64-NEXT: retq
;
@@ -3394,8 +3393,9 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
; KNL_32-NEXT: vptestmd %zmm0, %zmm0, %k0
; KNL_32-NEXT: kshiftlw $12, %k0, %k0
; KNL_32-NEXT: kshiftrw $12, %k0, %k1
-; KNL_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT: vpscatterdd %zmm1, (,%zmm0) {%k1}
+; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT: vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
; KNL_32-NEXT: vzeroupper
; KNL_32-NEXT: retl
;
@@ -3403,17 +3403,17 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
; SKX: # %bb.0:
; SKX-NEXT: vpslld $31, %xmm0, %xmm0
; SKX-NEXT: vpmovd2m %xmm0, %k1
-; SKX-NEXT: vpbroadcastq %rdi, %ymm0
-; SKX-NEXT: vpscatterqd %xmm1, (,%ymm0) {%k1}
-; SKX-NEXT: vzeroupper
+; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
; SKX-NEXT: retq
;
; SKX_32-LABEL: splat_ptr_scatter:
; SKX_32: # %bb.0:
; SKX_32-NEXT: vpslld $31, %xmm0, %xmm0
; SKX_32-NEXT: vpmovd2m %xmm0, %k1
-; SKX_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT: vpscatterdd %xmm1, (,%xmm0) {%k1}
+; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT: vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
; SKX_32-NEXT: retl
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
%2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index c1674ad4ca45..adb1930ca782 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,10 +87,9 @@ define <4 x i32> @global_struct_splat() {
define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT: [[TMP3:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
-; CHECK-NEXT: ret <4 x i32> [[TMP3]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
+; CHECK-NEXT: ret <4 x i32> [[TMP2]]
;
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
%2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
@@ -100,9 +99,8 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0