[llvm] 741c127 - [SelectionDAG] Add computeOverflowForSignedMul / computeOverflowForUnsignedMul overflow handlers

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 7 02:03:43 PDT 2023


Author: Mohamed Atef
Date: 2023-09-07T10:03:18+01:00
New Revision: 741c1278175b9354442cd2143e1452714dc020a2

URL: https://github.com/llvm/llvm-project/commit/741c1278175b9354442cd2143e1452714dc020a2
DIFF: https://github.com/llvm/llvm-project/commit/741c1278175b9354442cd2143e1452714dc020a2.diff

LOG: [SelectionDAG] Add computeOverflowForSignedMul / computeOverflowForUnsignedMul overflow handlers

Support signed multiplication
Support unsigned multiplication

Differential Revision: https://reviews.llvm.org/D159406

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/X86/combine-mulo.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index eaa2b74a73f850e..f25b3ae4b2d8165 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2020,6 +2020,24 @@ class SelectionDAG {
     return computeOverflowForSub(IsSigned, N0, N1) == OFK_Never;
   }
 
+  /// Determine whether the signed product N0 * N1 can overflow.
+  OverflowKind computeOverflowForSignedMul(SDValue N0, SDValue N1) const;
+
+  /// Determine whether the unsigned product N0 * N1 can overflow.
+  OverflowKind computeOverflowForUnsignedMul(SDValue N0, SDValue N1) const;
+
+  /// Dispatch to the signed or unsigned mul overflow query for N0 * N1.
+  OverflowKind computeOverflowForMul(bool IsSigned, SDValue N0,
+                                     SDValue N1) const {
+    return !IsSigned ? computeOverflowForUnsignedMul(N0, N1)
+                     : computeOverflowForSignedMul(N0, N1);
+  }
+
+  /// Return true only when N0 * N1 is proven never to overflow.
+  bool willNotOverflowMul(bool IsSigned, SDValue N0, SDValue N1) const {
+    return computeOverflowForMul(IsSigned, N0, N1) == OverflowKind::OFK_Never;
+  }
+
   /// Test if the given value is known to have exactly one bit set. This differs
   /// from computeKnownBits in that it doesn't necessarily determine which bit
   /// is set.

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 82c82c1c19bf082..d917e8c00c4f92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5432,34 +5432,18 @@ SDValue DAGCombiner::visitMULO(SDNode *N) {
     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
                        N->getVTList(), N0, N0);
 
-  if (IsSigned) {
-    // A 1 bit SMULO overflows if both inputs are 1.
-    if (VT.getScalarSizeInBits() == 1) {
-      SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
-      return CombineTo(N, And,
-                       DAG.getSetCC(DL, CarryVT, And,
-                                    DAG.getConstant(0, DL, VT), ISD::SETNE));
-    }
-
-    // Multiplying n * m significant bits yields a result of n + m significant
-    // bits. If the total number of significant bits does not exceed the
-    // result bit width (minus 1), there is no overflow.
-    unsigned SignBits = DAG.ComputeNumSignBits(N0);
-    if (SignBits > 1)
-      SignBits += DAG.ComputeNumSignBits(N1);
-    if (SignBits > VT.getScalarSizeInBits() + 1)
-      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
-                       DAG.getConstant(0, DL, CarryVT));
-  } else {
-    KnownBits N1Known = DAG.computeKnownBits(N1);
-    KnownBits N0Known = DAG.computeKnownBits(N0);
-    bool Overflow;
-    (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
-    if (!Overflow)
-      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
-                       DAG.getConstant(0, DL, CarryVT));
+  // A 1 bit SMULO overflows if both inputs are 1.
+  if (IsSigned && VT.getScalarSizeInBits() == 1) {
+    SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
+    SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
+                               DAG.getConstant(0, DL, VT), ISD::SETNE);
+    return CombineTo(N, And, Cmp);
   }
 
+  // If it cannot overflow, transform into a mul.
+  if (DAG.willNotOverflowMul(IsSigned, N0, N1))
+    return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
+                     DAG.getConstant(0, DL, CarryVT));
   return SDValue();
 }
 

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index eba51281dbf8101..b2ba747ce209867 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4099,6 +4099,49 @@ SelectionDAG::computeOverflowForUnsignedSub(SDValue N0, SDValue N1) const {
   return OFK_Sometime;
 }
 
+SelectionDAG::OverflowKind
+SelectionDAG::computeOverflowForUnsignedMul(SDValue N0, SDValue N1) const {
+  // Multiplying by a constant 0 or 1 can never wrap.
+  if (isNullConstant(N1) || isOneConstant(N1))
+    return OFK_Never;
+  // Otherwise ask ConstantRange whether the operands' unsigned ranges can wrap.
+  ConstantRange LHS =
+      ConstantRange::fromKnownBits(computeKnownBits(N0), /*IsSigned=*/false);
+  ConstantRange RHS =
+      ConstantRange::fromKnownBits(computeKnownBits(N1), /*IsSigned=*/false);
+  return mapOverflowResult(LHS.unsignedMulMayOverflow(RHS));
+}
+
+SelectionDAG::OverflowKind
+SelectionDAG::computeOverflowForSignedMul(SDValue N0, SDValue N1) const {
+  unsigned BitWidth = N0.getScalarValueSizeInBits();
+
+  // X * 0 never overflows; X * 1 never overflows unless the type is i1,
+  // where the constant 1 reads as -1 and (-1) * (-1) = +1 does not fit.
+  if (isNullConstant(N1) || (BitWidth > 1 && isOneConstant(N1)))
+    return OFK_Never;
+
+  // An operand with S sign bits has magnitude <= 2^(BitWidth - S), so the
+  // exact product has magnitude <= 2^(2 * BitWidth - SignBits).
+  unsigned SignBits = ComputeNumSignBits(N0) + ComputeNumSignBits(N1);
+
+  // 2 * BitWidth - SignBits < BitWidth - 1: the product always fits.
+  if (SignBits > BitWidth + 1)
+    return OFK_Never;
+
+  // On the boundary the only overflowing product is +2^(BitWidth - 1),
+  // reached when both operands are at their most negative (it would wrap
+  // to the signed minimum). A known non-negative operand rules that out.
+  if (SignBits == BitWidth + 1) {
+    KnownBits N0Known = computeKnownBits(N0);
+    KnownBits N1Known = computeKnownBits(N1);
+    if (N0Known.isNonNegative() || N1Known.isNonNegative())
+      return OFK_Never;
+  }
+
+  return OFK_Sometime;
+}
+
 bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val, unsigned Depth) const {
   if (Depth >= MaxRecursionDepth)
     return false; // Limit search depth.

diff --git a/llvm/test/CodeGen/X86/combine-mulo.ll b/llvm/test/CodeGen/X86/combine-mulo.ll
index e97cb589ab117e5..896269a288f56b2 100644
--- a/llvm/test/CodeGen/X86/combine-mulo.ll
+++ b/llvm/test/CodeGen/X86/combine-mulo.ll
@@ -96,7 +96,7 @@ define { i32, i1 } @combine_smul_nsw(i32 %a, i32 %b) {
 ; CHECK-NEXT:    andl $4095, %edi # imm = 0xFFF
 ; CHECK-NEXT:    andl $524287, %eax # imm = 0x7FFFF
 ; CHECK-NEXT:    imull %edi, %eax
-; CHECK-NEXT:    seto %dl
+; CHECK-NEXT:    xorl %edx, %edx
 ; CHECK-NEXT:    retq
   %aa = and i32 %a, 4095 ; 0xfff
   %bb = and i32 %b, 524287; 0x7ffff
@@ -109,19 +109,8 @@ define { <4 x i32>, <4 x i1> } @combine_vec_smul_nsw(<4 x i32> %a, <4 x i32> %b)
 ; SSE:       # %bb.0:
 ; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; SSE-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[1,1,3,3]
-; SSE-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
-; SSE-NEXT:    pmuldq %xmm2, %xmm3
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    pmuldq %xmm1, %xmm2
-; SSE-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1],xmm3[2,3],xmm2[4,5],xmm3[6,7]
-; SSE-NEXT:    pxor %xmm3, %xmm3
-; SSE-NEXT:    pcmpeqd %xmm2, %xmm3
-; SSE-NEXT:    pcmpeqd %xmm2, %xmm2
-; SSE-NEXT:    pxor %xmm3, %xmm2
 ; SSE-NEXT:    pmulld %xmm1, %xmm0
-; SSE-NEXT:    movdqa %xmm2, %xmm1
+; SSE-NEXT:    pxor %xmm1, %xmm1
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_smul_nsw:
@@ -129,18 +118,9 @@ define { <4 x i32>, <4 x i1> } @combine_vec_smul_nsw(<4 x i32> %a, <4 x i32> %b)
 ; AVX-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [4095,4095,4095,4095]
 ; AVX-NEXT:    vpand %xmm2, %xmm0, %xmm0
 ; AVX-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [524287,524287,524287,524287]
-; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm2
-; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[1,1,3,3]
-; AVX-NEXT:    vpshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
-; AVX-NEXT:    vpmuldq %xmm1, %xmm3, %xmm1
-; AVX-NEXT:    vpmuldq %xmm2, %xmm0, %xmm3
-; AVX-NEXT:    vpshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]
-; AVX-NEXT:    vpblendd {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2],xmm1[3]
-; AVX-NEXT:    vpxor %xmm3, %xmm3, %xmm3
-; AVX-NEXT:    vpcmpeqd %xmm3, %xmm1, %xmm1
-; AVX-NEXT:    vpcmpeqd %xmm3, %xmm3, %xmm3
-; AVX-NEXT:    vpxor %xmm3, %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX-NEXT:    retq
   %aa = and <4 x i32> %a, <i32 4095, i32 4095, i32 4095, i32 4095>
   %bb = and <4 x i32> %b, <i32 524287, i32 524287, i32 524287, i32 524287>


        


More information about the llvm-commits mailing list