[llvm] 213e308 - [DAG] Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 15 13:10:28 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-15T21:10:00+01:00
New Revision: 213e308633e533f74f04269766989bb89fde0921

URL: https://github.com/llvm/llvm-project/commit/213e308633e533f74f04269766989bb89fde0921
DIFF: https://github.com/llvm/llvm-project/commit/213e308633e533f74f04269766989bb89fde0921.diff

LOG: [DAG] Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)

Similar to the InstCombine implementation, except that we don't have to handle the NSW/is_int_min_poison case.
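
For reference, here is a standalone scalar sketch (not part of the patch; the
helper name mulToAbs is hypothetical) of the branchless abs idiom this combine
recognizes:

  #include <cassert>
  #include <cstdint>

  int32_t mulToAbs(int32_t X) {
    // sra(X, size(X)-1): all-zeros for non-negative X, all-ones for
    // negative X. (Arithmetic right shift of negative signed values is
    // guaranteed from C++20; in practice all common targets do this.)
    int32_t Y = X >> 31;
    // (Y | 1) is +1 or -1. The multiply is done in unsigned arithmetic to
    // get the wrapping semantics of ISD::MUL: INT32_MIN maps to itself,
    // which matches ISD::ABS.
    return (int32_t)((uint32_t)(Y | 1) * (uint32_t)X);
  }

  int main() {
    assert(mulToAbs(5) == 5);
    assert(mulToAbs(-7) == 7);
    assert(mulToAbs(0) == 0);
    assert(mulToAbs(INT32_MIN) == INT32_MIN);
    return 0;
  }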

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/combine-mul.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 80b8d48251472..78763e729ebd5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4326,6 +4326,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N0.getValueType();
+  unsigned BitWidth = VT.getScalarSizeInBits();
   SDLoc DL(N);
   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
   MatchContextClass Matcher(DAG, TLI, N);
@@ -4355,8 +4356,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
         return FoldedVOp;
 
     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
-    assert((!N1IsConst ||
-            ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
+    assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
            "Splat APInt should be element width");
   } else {
     N1IsConst = isa<ConstantSDNode>(N1);
@@ -4456,7 +4456,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
       unsigned ShAmt =
           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
       ShAmt += TZeros;
-      assert(ShAmt < VT.getScalarSizeInBits() &&
+      assert(ShAmt < BitWidth &&
              "multiply-by-constant generated out of bounds shift");
       SDValue Shl =
           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
@@ -4525,6 +4525,16 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
     return DAG.getStepVector(DL, VT, NewStep);
   }
 
+  // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
+  SDValue X;
+  if (!UseVP && (!LegalOperations || hasOperation(ISD::ABS, VT)) &&
+      sd_context_match(
+          N, Matcher,
+          m_Mul(m_Or(m_Sra(m_Value(X), m_SpecificInt(BitWidth - 1)), m_One()),
+                m_Deferred(X)))) {
+    return Matcher.getNode(ISD::ABS, DL, VT, X);
+  }
+
   // Fold ((mul x, 0/undef) -> 0,
   //       (mul x, 1) -> x) -> x)
   // -> and(x, mask)

diff --git a/llvm/test/CodeGen/X86/combine-mul.ll b/llvm/test/CodeGen/X86/combine-mul.ll
index 7837843ce0917..8e4a50ea266c3 100644
--- a/llvm/test/CodeGen/X86/combine-mul.ll
+++ b/llvm/test/CodeGen/X86/combine-mul.ll
@@ -282,39 +282,17 @@ define <4 x i32> @combine_vec_mul_add(<4 x i32> %x) {
   ret <4 x i32> %2
 }
 
-; TODO fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
+; fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
 
 define <16 x i8> @combine_mul_to_abs_v16i8(<16 x i8> %x) {
 ; SSE-LABEL: combine_mul_to_abs_v16i8:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm2, %xmm2
-; SSE-NEXT:    pcmpgtb %xmm0, %xmm2
-; SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
-; SSE-NEXT:    pmovzxbw {{.*#+}} xmm3 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; SSE-NEXT:    punpckhbw {{.*#+}} xmm0 = xmm0[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
-; SSE-NEXT:    pmovzxbw {{.*#+}} xmm1 = xmm2[0],zero,xmm2[1],zero,xmm2[2],zero,xmm2[3],zero,xmm2[4],zero,xmm2[5],zero,xmm2[6],zero,xmm2[7],zero
-; SSE-NEXT:    punpckhbw {{.*#+}} xmm2 = xmm2[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
-; SSE-NEXT:    pmullw %xmm0, %xmm2
-; SSE-NEXT:    pmovzxbw {{.*#+}} xmm0 = [255,255,255,255,255,255,255,255]
-; SSE-NEXT:    pand %xmm0, %xmm2
-; SSE-NEXT:    pmullw %xmm3, %xmm1
-; SSE-NEXT:    pand %xmm0, %xmm1
-; SSE-NEXT:    packuswb %xmm2, %xmm1
-; SSE-NEXT:    movdqa %xmm1, %xmm0
+; SSE-NEXT:    pabsb %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_mul_to_abs_v16i8:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpcmpgtb %xmm0, %xmm1, %xmm1
-; AVX-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX-NEXT:    vpmovzxbw {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero,xmm1[8],zero,xmm1[9],zero,xmm1[10],zero,xmm1[11],zero,xmm1[12],zero,xmm1[13],zero,xmm1[14],zero,xmm1[15],zero
-; AVX-NEXT:    vpmovzxbw {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero,xmm0[8],zero,xmm0[9],zero,xmm0[10],zero,xmm0[11],zero,xmm0[12],zero,xmm0[13],zero,xmm0[14],zero,xmm0[15],zero
-; AVX-NEXT:    vpmullw %ymm0, %ymm1, %ymm0
-; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
-; AVX-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX-NEXT:    vpackuswb %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vpabsb %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %s = ashr <16 x i8> %x, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
   %o = or <16 x i8> %s, <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
@@ -325,34 +303,16 @@ define <16 x i8> @combine_mul_to_abs_v16i8(<16 x i8> %x) {
 define <2 x i64> @combine_mul_to_abs_v2i64(<2 x i64> %x) {
 ; SSE-LABEL: combine_mul_to_abs_v2i64:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,3,3]
-; SSE-NEXT:    psrad $31, %xmm1
-; SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrlq $32, %xmm2
-; SSE-NEXT:    pmuludq %xmm1, %xmm2
-; SSE-NEXT:    movdqa %xmm1, %xmm3
-; SSE-NEXT:    psrlq $32, %xmm3
-; SSE-NEXT:    pmuludq %xmm0, %xmm3
-; SSE-NEXT:    paddq %xmm2, %xmm3
-; SSE-NEXT:    psllq $32, %xmm3
-; SSE-NEXT:    pmuludq %xmm1, %xmm0
-; SSE-NEXT:    paddq %xmm3, %xmm0
+; SSE-NEXT:    pxor %xmm1, %xmm1
+; SSE-NEXT:    psubq %xmm0, %xmm1
+; SSE-NEXT:    blendvpd %xmm0, %xmm1, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_mul_to_abs_v2i64:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpcmpgtq %xmm0, %xmm1, %xmm1
-; AVX-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX-NEXT:    vpsrlq $32, %xmm0, %xmm2
-; AVX-NEXT:    vpmuludq %xmm1, %xmm2, %xmm2
-; AVX-NEXT:    vpsrlq $32, %xmm1, %xmm3
-; AVX-NEXT:    vpmuludq %xmm3, %xmm0, %xmm3
-; AVX-NEXT:    vpaddq %xmm2, %xmm3, %xmm2
-; AVX-NEXT:    vpsllq $32, %xmm2, %xmm2
-; AVX-NEXT:    vpmuludq %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpaddq %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vpsubq %xmm0, %xmm1, %xmm1
+; AVX-NEXT:    vblendvpd %xmm0, %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %s = ashr <2 x i64> %x, <i64 63, i64 63>
   %o = or <2 x i64> %s, <i64 1, i64 1>
