[llvm] select (sext m), (add X, C), X --> (add X, (and C, (sext m))) (PR #83640)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 1 17:04:40 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: None (elhewaty)

<details>
<summary>Changes</summary>

- [DAG][X86] Add tests for folding select m, add(X, C), X --> add(X, and(C, m)) (NFC)
- [DAG][X86] Fold select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
- Fixes: https://github.com/llvm/llvm-project/issues/66101


---
Full diff: https://github.com/llvm/llvm-project/pull/83640.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+9) 
- (modified) llvm/test/CodeGen/X86/vselect.ll (+32) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 33ada3655dc731..771f9e96f9dc64 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12070,6 +12070,15 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
     return DAG.getSelect(DL, VT, F, N2, N1);
 
+  // select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
+  if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
+      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) &&
+      N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits()) {
+    return DAG.getNode(
+        ISD::ADD, DL, N1.getValueType(), N2,
+        DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1), N0));
+  }
+
   // Canonicalize integer abs.
   // vselect (setg[te] X,  0),  X, -X ->
   // vselect (setgt    X, -1),  X, -X ->
diff --git a/llvm/test/CodeGen/X86/vselect.ll b/llvm/test/CodeGen/X86/vselect.ll
index cc4eb0c8f7343b..9600afdc18c147 100644
--- a/llvm/test/CodeGen/X86/vselect.ll
+++ b/llvm/test/CodeGen/X86/vselect.ll
@@ -7,6 +7,38 @@
 ; Verify that we don't emit packed vector shifts instructions if the
 ; condition used by the vector select is a vector of constants.
 
+define <2 x i64> @masked_select_const(<2 x i64> %a, <2 x i64> %x, <2 x i64> %y) {
+; SSE-LABEL: masked_select_const:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pcmpgtd %xmm2, %xmm1
+; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE-NEXT:    paddd %xmm1, %xmm0
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: masked_select_const:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vpcmpgtd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: masked_select_const:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm3 = [4294967272,4294967272,4294967272,4294967272]
+; AVX2-NEXT:    vpcmpgtd %xmm2, %xmm1, %xmm1
+; AVX2-NEXT:    vpand %xmm3, %xmm1, %xmm1
+; AVX2-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    retq
+  %bit_a = bitcast <2 x i64> %a to <4 x i32>
+  %sub.i = add <4 x i32> %bit_a, <i32 -24, i32 -24, i32 -24, i32 -24>
+  %bit_x = bitcast <2 x i64> %x to <4 x i32>
+  %bit_y = bitcast <2 x i64> %y to <4 x i32>
+  %cmp.i = icmp sgt <4 x i32> %bit_x, %bit_y
+  %sel = select <4 x i1> %cmp.i, <4 x i32> %sub.i, <4 x i32> %bit_a
+  %bit_sel = bitcast <4 x i32> %sel to <2 x i64>
+  ret <2 x i64> %bit_sel
+}
+
 define <4 x float> @test1(<4 x float> %a, <4 x float> %b) {
 ; SSE2-LABEL: test1:
 ; SSE2:       # %bb.0:

``````````

</details>


https://github.com/llvm/llvm-project/pull/83640


More information about the llvm-commits mailing list