[llvm] 0938cdb - [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDUBSW nodes

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 15 10:32:20 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-15T18:25:11+01:00
New Revision: 0938cdbfbd946e4033b070ec4fbd9dfd790eb91a

URL: https://github.com/llvm/llvm-project/commit/0938cdbfbd946e4033b070ec4fbd9dfd790eb91a
DIFF: https://github.com/llvm/llvm-project/commit/0938cdbfbd946e4033b070ec4fbd9dfd790eb91a.diff

LOG: [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDUBSW nodes
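For context: (V)PMADDUBSW vertically multiplies each unsigned i8 element
of the first operand with the corresponding signed i8 element of the
second operand, then horizontally adds each adjacent pair of signed i16
products with signed saturation. A minimal scalar model of one output
lane, as a sketch (the pmaddubswLane helper name is made up for
illustration, it is not part of the patch):

    #include <algorithm>
    #include <cstdint>

    // One i16 output lane of (V)PMADDUBSW is built from two adjacent
    // byte pairs: unsigned bytes from A, signed bytes from B.
    static int16_t pmaddubswLane(uint8_t ALo, uint8_t AHi,
                                 int8_t BLo, int8_t BHi) {
      int32_t Lo = int32_t(ALo) * int32_t(BLo); // zext(u8) * sext(i8)
      int32_t Hi = int32_t(AHi) * int32_t(BHi);
      // Horizontal add with signed i16 saturation.
      return int16_t(std::clamp<int32_t>(Lo + Hi, INT16_MIN, INT16_MAX));
    }

This is what the new computeKnownBitsForPMADDUBSW helper below mirrors
in KnownBits arithmetic: zext the LHS bytes, sext the RHS bytes,
KnownBits::mul each pair, then combine with KnownBits::sadd_sat.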

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6aa1a5b52bb67..f27c935812f51 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37109,6 +37109,33 @@ static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
                                       /*NUW=*/false, Lo, Hi);
 }
 
+static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
+                                         KnownBits &Known,
+                                         const APInt &DemandedElts,
+                                         const SelectionDAG &DAG,
+                                         unsigned Depth) {
+  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+
+  // Multiply signed/unsigned i8 elements to create i16 values and add_sat Lo/Hi
+  // pairs.
+  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+  APInt DemandedLoElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+  APInt DemandedHiElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+  KnownBits LHSLo =
+      DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1).zext(16);
+  KnownBits LHSHi =
+      DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1).zext(16);
+  KnownBits RHSLo =
+      DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1).sext(16);
+  KnownBits RHSHi =
+      DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1).sext(16);
+  KnownBits Lo = KnownBits::mul(LHSLo, RHSLo);
+  KnownBits Hi = KnownBits::mul(LHSHi, RHSHi);
+  Known = KnownBits::sadd_sat(Lo, Hi);
+}
+
 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
                                                       KnownBits &Known,
                                                       const APInt &DemandedElts,
@@ -37294,6 +37321,16 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
     break;
   }
+  case X86ISD::VPMADDUBSW: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+    assert(VT.getVectorElementType() == MVT::i16 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getVectorElementType() == MVT::i8 &&
+           "Unexpected PMADDUBSW types");
+    computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+    break;
+  }
   case X86ISD::PMULUDQ: {
     KnownBits Known2;
     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37442,6 +37479,18 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
       computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
       break;
     }
+    case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
+    case Intrinsic::x86_avx2_pmadd_ub_sw:
+    case Intrinsic::x86_avx512_pmaddubs_w_512: {
+      SDValue LHS = Op.getOperand(1);
+      SDValue RHS = Op.getOperand(2);
+      assert(VT.getScalarType() == MVT::i16 &&
+             LHS.getValueType() == RHS.getValueType() &&
+             LHS.getValueType().getScalarType() == MVT::i8 &&
+             "Unexpected PMADDUBSW types");
+      computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+      break;
+    }
     case Intrinsic::x86_sse2_psad_bw:
     case Intrinsic::x86_avx2_psad_bw:
     case Intrinsic::x86_avx512_psad_bw_512: {

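A note on the demanded-elements handling above: each i16 output lane
consumes two adjacent i8 source lanes, so the helper scales the demanded
mask up to the source element count and then splits it into the even
(lo) and odd (hi) byte positions with 0b01/0b10 splat masks. A small
sketch of the resulting APInt values, assuming LLVM's APInt API (the
demandedMasksExample name is invented for illustration):

    #include "llvm/ADT/APInt.h"
    using namespace llvm;

    void demandedMasksExample() {
      // v8i16 result with only output lane 3 demanded.
      APInt DemandedElts(8, 0b00001000);
      // Scale to the v16i8 sources: bits 6 and 7 become set.
      APInt Src = APIntOps::ScaleBitMask(DemandedElts, 16);
      // The 0b01 splat (0x5555) keeps the even/lo bytes, the 0b10
      // splat (0xAAAA) keeps the odd/hi bytes.
      APInt Lo = Src & APInt::getSplat(16, APInt(2, 0b01)); // bit 6 set
      APInt Hi = Src & APInt::getSplat(16, APInt(2, 0b10)); // bit 7 set
    }
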
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 8c8da55503aa1..ad3b9e645f628 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -73,21 +73,10 @@ define <8 x i16> @combine_pmaddubsw_zero_commute(<16 x i8> %a0, <16 x i8> %a1) {
 }
 
 define i32 @combine_pmaddubsw_constant() {
-; SSE-LABEL: combine_pmaddubsw_constant:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT:    pextrw $3, %xmm0, %eax
-; SSE-NEXT:    cwtl
-; SSE-NEXT:    retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT:    vpextrw $3, %xmm0, %eax
-; AVX-NEXT:    cwtl
-; AVX-NEXT:    retq
+; CHECK-LABEL: combine_pmaddubsw_constant:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl $1694, %eax # imm = 0x69E
+; CHECK-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
  %2 = extractelement <8 x i16> %1, i32 3 ; ((uint8_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
   %3 = sext i16 %2 to i32
@@ -95,21 +84,10 @@ define i32 @combine_pmaddubsw_constant() {
 }
 
 define i32 @combine_pmaddubsw_constant_sat() {
-; SSE-LABEL: combine_pmaddubsw_constant_sat:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT:    movd %xmm0, %eax
-; SSE-NEXT:    cwtl
-; SSE-NEXT:    retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant_sat:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    cwtl
-; AVX-NEXT:    retq
+; CHECK-LABEL: combine_pmaddubsw_constant_sat:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl $-32768, %eax # imm = 0x8000
+; CHECK-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
  %2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint8_t)-1*-128),((uint8_t)-1*-128)) = add_sat_i16((255*-128),(255*-128)) = sat_i16(-65280) = -32768
   %3 = sext i16 %2 to i32
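As a sanity check, the two constants now folded in combine-pmadd.ll
agree with the scalar model sketched earlier (hypothetical assertions,
assuming the pmaddubswLane sketch above is in scope):

    #include <cassert>

    int main() {
      // combine_pmaddubsw_constant, lane 3: (uint8_t)-6 == 250, so
      // 250*7 + 7*(-8) = 1750 - 56 = 1694.
      assert(pmaddubswLane(250, 7, 7, -8) == 1694);
      // combine_pmaddubsw_constant_sat, lane 0: 255*(-128) added twice
      // is -65280, which saturates to the i16 minimum.
      assert(pmaddubswLane(255, 255, -128, -128) == -32768);
      return 0;
    }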
