[llvm] f857ed6 - [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDWD nodes

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 15 07:40:45 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-15T15:40:30+01:00
New Revision: f857ed623ca2536968804ecb6e7ad3b686e09700

URL: https://github.com/llvm/llvm-project/commit/f857ed623ca2536968804ecb6e7ad3b686e09700
DIFF: https://github.com/llvm/llvm-project/commit/f857ed623ca2536968804ecb6e7ad3b686e09700.diff

LOG: [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDWD nodes

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 41dd4dc447bb1..6aa1a5b52bb67 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37082,6 +37082,33 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
   Known = Known.zext(64);
 }
 
+// Compute known bits for (V)PMADDWD: every i32 result element is the sum of
+// the two sign-extended i16 x i16 products of an adjacent Lo/Hi source pair.
+static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
+                                       KnownBits &Known,
+                                       const APInt &DemandedElts,
+                                       const SelectionDAG &DAG,
+                                       unsigned Depth) {
+  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+
+  // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
+  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+  APInt DemandedLoElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+  APInt DemandedHiElts =
+      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+  KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
+  KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
+  KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
+  KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
+  KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
+  KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
+  // The sum wraps to 0x80000000 when all four i16 inputs are 0x8000, so the
+  // add must not be treated as NSW.
+  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                      /*NUW=*/false, Lo, Hi);
+}
+
 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
                                                       KnownBits &Known,
                                                       const APInt &DemandedElts,
@@ -37257,6 +37284,16 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     }
     break;
   }
+  case X86ISD::VPMADDWD: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+    assert(VT.getVectorElementType() == MVT::i32 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getVectorElementType() == MVT::i16 &&
+           "Unexpected PMADDWD types");
+    computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+    break;
+  }
   case X86ISD::PMULUDQ: {
     KnownBits Known2;
     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37393,6 +37430,18 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   }
   case ISD::INTRINSIC_WO_CHAIN: {
     switch (Op->getConstantOperandVal(0)) {
+    case Intrinsic::x86_sse2_pmadd_wd:
+    case Intrinsic::x86_avx2_pmadd_wd:
+    case Intrinsic::x86_avx512_pmaddw_d_512: {
+      SDValue LHS = Op.getOperand(1);
+      SDValue RHS = Op.getOperand(2);
+      assert(VT.getScalarType() == MVT::i32 &&
+             LHS.getValueType() == RHS.getValueType() &&
+             LHS.getValueType().getScalarType() == MVT::i16 &&
+             "Unexpected PMADDWD types");
+      computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+      break;
+    }
     case Intrinsic::x86_sse2_psad_bw:
     case Intrinsic::x86_avx2_psad_bw:
     case Intrinsic::x86_avx512_psad_bw_512: {

diff  --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 0a4a59754b614..8a6adbdeb64d8 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s --check-prefix=SSE
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx | FileCheck %s --check-prefixes=AVX
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefixes=AVX
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s --check-prefixes=CHECK,SSE
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx | FileCheck %s --check-prefixes=CHECK,AVX
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefixes=CHECK,AVX
 
 declare <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>) nounwind readnone
 declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind readnone
@@ -34,21 +34,11 @@ define <4 x i32> @combine_pmaddwd_zero_commute(<8 x i16> %a0, <8 x i16> %a1) {
   ret <4 x i32> %1
 }
 
-; TODO: pmaddwd knownbits handling
 define i32 @combine_pmaddwd_constant() {
-; SSE-LABEL: combine_pmaddwd_constant:
-; SSE:       # %bb.0:
-; SSE-NEXT:    pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; SSE-NEXT:    pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
-; SSE-NEXT:    pextrd $2, %xmm0, %eax
-; SSE-NEXT:    retq
-;
-; AVX-LABEL: combine_pmaddwd_constant:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; AVX-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
-; AVX-NEXT:    vpextrd $2, %xmm0, %eax
-; AVX-NEXT:    retq
+; CHECK-LABEL: combine_pmaddwd_constant:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl $-155, %eax
+; CHECK-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
   %2 = extractelement <4 x i32> %1, i32 2 ; (-5*13)+(6*-15) = -155
   ret i32 %2


        


More information about the llvm-commits mailing list