[llvm] ca4b1f8 - [X86] computeKnownBitsForTargetNode - add handling for PMADDWD/PMADDUBSW nodes

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 08:29:32 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-28T16:29:14+01:00
New Revision: ca4b1f8629d3162a572a0888322789c56bb75921

URL: https://github.com/llvm/llvm-project/commit/ca4b1f8629d3162a572a0888322789c56bb75921
DIFF: https://github.com/llvm/llvm-project/commit/ca4b1f8629d3162a572a0888322789c56bb75921.diff

LOG: [X86] computeKnownBitsForTargetNode - add handling for PMADDWD/PMADDUBSW nodes

This handling was previously reverted in fa0e9acea5e4d363eef6acc484afc1b22ab8e698 while we triaged an infinite loop regression; this commit relands it.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0cd5b09b10ba8..50cacc3038b0c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37139,6 +37139,52 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
   Known = Known.zext(64);
 }
 
+static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
+                                       KnownBits &Known,
+                                       const APInt &DemandedElts,
+                                       const SelectionDAG &DAG,
+                                       unsigned Depth) {
+  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+  // Result i = LHS[2i]*RHS[2i] + LHS[2i+1]*RHS[2i+1] with signed i16 inputs.
+  // Products fit in i32; the add wraps only if all four inputs are -32768.
+  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+  APInt DemandedLoElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+  APInt DemandedHiElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+  KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
+  KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
+  KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
+  KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
+  KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
+  KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
+  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                      /*NUW=*/false, Lo, Hi);
+}
+
+static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
+                                         KnownBits &Known,
+                                         const APInt &DemandedElts,
+                                         const SelectionDAG &DAG,
+                                         unsigned Depth) {
+  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+  // Result i = sadd_sat(LHS[2i]*RHS[2i], LHS[2i+1]*RHS[2i+1]) in i16, with
+  // unsigned i8 LHS (zext) and signed i8 RHS (sext) inputs. Products fit in
+  // i16 without wrapping; only the pairwise add needs to saturate.
+  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+  APInt DemandedLoElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+  APInt DemandedHiElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+  KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
+  KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
+  KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
+  KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
+  KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
+  KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
+  Known = KnownBits::sadd_sat(Lo, Hi);
+}
+
 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
                                                       KnownBits &Known,
                                                       const APInt &DemandedElts,
@@ -37314,6 +37360,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     }
     break;
   }
+  case X86ISD::VPMADDWD: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+    assert(VT.getVectorElementType() == MVT::i32 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getVectorElementType() == MVT::i16 &&
+           "Unexpected PMADDWD types");
+    computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+    break;
+  }
+  case X86ISD::VPMADDUBSW: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+    assert(VT.getVectorElementType() == MVT::i16 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getVectorElementType() == MVT::i8 &&
+           "Unexpected PMADDUBSW types");
+    computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+    break;
+  }
   case X86ISD::PMULUDQ: {
     KnownBits Known2;
     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37450,6 +37516,30 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   }
   case ISD::INTRINSIC_WO_CHAIN: {
     switch (Op->getConstantOperandVal(0)) {
+    case Intrinsic::x86_sse2_pmadd_wd:
+    case Intrinsic::x86_avx2_pmadd_wd:
+    case Intrinsic::x86_avx512_pmaddw_d_512: {
+      SDValue LHS = Op.getOperand(1);
+      SDValue RHS = Op.getOperand(2);
+      assert(VT.getScalarType() == MVT::i32 &&
+             LHS.getValueType() == RHS.getValueType() &&
+             LHS.getValueType().getScalarType() == MVT::i16 &&
+             "Unexpected PMADDWD types");
+      computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+      break;
+    }
+    case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
+    case Intrinsic::x86_avx2_pmadd_ub_sw:
+    case Intrinsic::x86_avx512_pmaddubs_w_512: {
+      SDValue LHS = Op.getOperand(1);
+      SDValue RHS = Op.getOperand(2);
+      assert(VT.getScalarType() == MVT::i16 &&
+             LHS.getValueType() == RHS.getValueType() &&
+             LHS.getValueType().getScalarType() == MVT::i8 &&
+             "Unexpected PMADDUBSW types");
+      computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+      break;
+    }
     case Intrinsic::x86_sse2_psad_bw:
     case Intrinsic::x86_avx2_psad_bw:
     case Intrinsic::x86_avx512_psad_bw_512: {

diff  --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 83cce1061bc1e..565d9ef6eb3e6 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -229,35 +229,10 @@ define i32 @combine_pmaddubsw_constant_sat() {
 
 ; Constant folding PMADDWD was causing an infinite loop in the PCMPGT commuting between 2 constant values.
 define i1 @pmaddwd_pcmpgt_infinite_loop() {
-; SSE-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [2147483647,2147483647,2147483647,2147483647]
-; SSE-NEXT:    paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE-NEXT:    movmskps %xmm0, %eax
-; SSE-NEXT:    testl %eax, %eax
-; SSE-NEXT:    sete %al
-; SSE-NEXT:    retq
-;
-; AVX1-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    vpcmpeqd %xmm0, %xmm0, %xmm0
-; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
-; AVX1-NEXT:    vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vtestps %xmm1, %xmm0
-; AVX1-NEXT:    sete %al
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpcmpeqd %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
-; AVX2-NEXT:    vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX2-NEXT:    vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX2-NEXT:    vtestps %xmm1, %xmm0
-; AVX2-NEXT:    sete %al
-; AVX2-NEXT:    retq
+; CHECK-LABEL: pmaddwd_pcmpgt_infinite_loop:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movb $1, %al
+; CHECK-NEXT:    retq
   %1 = tail call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>, <8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>)
   %2 = icmp eq <4 x i32> %1, <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
   %3 = select <4 x i1> %2, <4 x i32> <i32 2147483647, i32 2147483647, i32 2147483647, i32 2147483647>, <4 x i32> zeroinitializer


        


More information about the llvm-commits mailing list