[llvm] r338090 - [X86] When removing sign extends from gather/scatter indices, make sure we handle UpdateNodeOperands finding an existing node to CSE with.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 26 17:00:31 PDT 2018
Author: ctopper
Date: Thu Jul 26 17:00:30 2018
New Revision: 338090
URL: http://llvm.org/viewvc/llvm-project?rev=338090&view=rev
Log:
[X86] When removing sign extends from gather/scatter indices, make sure we handle UpdateNodeOperands finding an existing node to CSE with.
If this happens, the operands aren't updated and the existing node is returned instead. Make sure we pass this existing node up to the DAG combiner so that a proper replacement happens. Otherwise we get stuck in an infinite loop with an unoptimized node.
Modified:
llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll
Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=338090&r1=338089&r2=338090&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Thu Jul 26 17:00:30 2018
@@ -38100,12 +38100,14 @@ static SDValue combineGatherScatter(SDNo
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- // The original sign extend has less users, add back to worklist in case
- // it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N) {
+ // The original sign extend has less users, add back to worklist in
+ // case it needs to be removed
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ }
+ return SDValue(Res, 0);
}
}
@@ -38118,9 +38120,10 @@ static SDValue combineGatherScatter(SDNo
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index;
- DAG.UpdateNodeOperands(N, NewOps);
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N)
+ DCI.AddToWorklist(N);
+ return SDValue(Res, 0);
}
// Try to remove zero extends from 32->64 if we know the sign bit of
@@ -38131,12 +38134,14 @@ static SDValue combineGatherScatter(SDNo
if (DAG.SignBitIsZero(Index.getOperand(0))) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- // The original zero extend has less users, add back to worklist in case
- // it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N) {
+ // The original sign extend has less users, add back to worklist in
+ // case it needs to be removed
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ }
+ return SDValue(Res, 0);
}
}
}
Modified: llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll?rev=338090&r1=338089&r2=338090&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll (original)
+++ llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll Thu Jul 26 17:00:30 2018
@@ -2928,3 +2928,54 @@ define void @test_scatter_setcc_split(do
call void @llvm.masked.scatter.v16f64.v16p0f64(<16 x double> %src0, <16 x double*> %gep.random, i32 4, <16 x i1> %mask)
ret void
}
+
+; This test case previously triggered an infinite loop when the two gathers became identical after DAG combine removed the sign extend.
+define <16 x float> @test_sext_cse(float* %base, <16 x i32> %ind, <16 x i32>* %foo) {
+; KNL_64-LABEL: test_sext_cse:
+; KNL_64: # %bb.0:
+; KNL_64-NEXT: vmovaps %zmm0, (%rsi)
+; KNL_64-NEXT: kxnorw %k0, %k0, %k1
+; KNL_64-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; KNL_64-NEXT: vaddps %zmm1, %zmm1, %zmm0
+; KNL_64-NEXT: retq
+;
+; KNL_32-LABEL: test_sext_cse:
+; KNL_32: # %bb.0:
+; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %ecx
+; KNL_32-NEXT: vmovaps %zmm0, (%ecx)
+; KNL_32-NEXT: kxnorw %k0, %k0, %k1
+; KNL_32-NEXT: vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
+; KNL_32-NEXT: vaddps %zmm1, %zmm1, %zmm0
+; KNL_32-NEXT: retl
+;
+; SKX-LABEL: test_sext_cse:
+; SKX: # %bb.0:
+; SKX-NEXT: vmovaps %zmm0, (%rsi)
+; SKX-NEXT: kxnorw %k0, %k0, %k1
+; SKX-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; SKX-NEXT: vaddps %zmm1, %zmm1, %zmm0
+; SKX-NEXT: retq
+;
+; SKX_32-LABEL: test_sext_cse:
+; SKX_32: # %bb.0:
+; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %ecx
+; SKX_32-NEXT: vmovaps %zmm0, (%ecx)
+; SKX_32-NEXT: kxnorw %k0, %k0, %k1
+; SKX_32-NEXT: vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
+; SKX_32-NEXT: vaddps %zmm1, %zmm1, %zmm0
+; SKX_32-NEXT: retl
+ %broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0
+ %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
+
+ %sext_ind = sext <16 x i32> %ind to <16 x i64>
+ %gep.random = getelementptr float, <16 x float*> %broadcast.splat, <16 x i64> %sext_ind
+
+ store <16 x i32> %ind, <16 x i32>* %foo
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+ %gep.random2 = getelementptr float, <16 x float*> %broadcast.splat, <16 x i32> %ind
+ %res2 = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random2, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+ %res3 = fadd <16 x float> %res2, %res
+ ret <16 x float>%res3
+}
More information about the llvm-commits mailing list