[llvm] [InstCombine] Pattern match minmax calls for unsigned saturation. (PR #99250)

Huihui Zhang via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 23 21:05:26 PDT 2024


================
@@ -1117,68 +1117,108 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
-/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
+/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
+/// value.
+Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
 
-  // We are looking for a tree of:
-  // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
-  // Where the min and max could be reversed
-  Instruction *MinMax2;
+  // 1. We are looking for a tree of signed saturation:
+  //    smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
+  //    Where the smin and smax could be reversed.
+  // 2. A tree of unsigned saturation:
+  //    smax(UINT_MIN, sub(zext(A), zext(B)))
+  //    Or umin(UINT_MAX, add(zext(A), zext(B))).
+  Instruction *MinMax2 = nullptr;
   BinaryOperator *AddSub;
-  const APInt *MinValue, *MaxValue;
-  if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
-    if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+  const APInt *MinValue = nullptr, *MaxValue = nullptr;
+  bool IsUnsignedSaturate = false;
+  // Pattern match for unsigned saturation.
+  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
+    // Bail out if AddSub could be negative.
+    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
       return nullptr;
-  } else if (match(&MinMax1,
-                   m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
-    if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+    IsUnsignedSaturate = true;
+  } else if (match(&MinMax1, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) {
+    if (!MinValue->isZero())
       return nullptr;
-  } else
-    return nullptr;
+    IsUnsignedSaturate = true;
+  } else {
+    // Pattern match for signed saturation.
+    if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
+      if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+        return nullptr;
+    } else if (match(&MinMax1,
+                     m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
+      if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+        return nullptr;
+    } else
+      return nullptr;
+  }
 
   // Check that the constants clamp a saturate, and that the new type would be
   // sensible to convert to.
-  if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+  if ((MaxValue && !(*MaxValue + 1).isPowerOf2()) ||
+      (!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
     return nullptr;
-  // In what bitwidth can this be treated as saturating arithmetics?
-  unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+
+  // Trying to decide the bitwidth for saturating arithmetics.
+  Value *Op0 = AddSub->getOperand(0);
+  Value *Op1 = AddSub->getOperand(1);
+  unsigned Op0MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op0, 0, AddSub);
+  unsigned Op1MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op1, 0, AddSub);
+  unsigned NewBitWidth = IsUnsignedSaturate
+                             ? std::max(Op0MaxBitWidth, Op1MaxBitWidth)
+                             : (*MaxValue + 1).logBase2() + 1;
----------------
huihzhang wrote:

The problem is that when folding "smax(UINT_MIN, sub(zext(A), zext(B))) -> usub_sat", MaxValue is not available, so I was using computeKnownBits to determine NewBitWidth instead.
I pushed an update that, when MaxValue is not available, first sets NewBitWidth to half the bit width of MinMax1.
It then uses the results of computeKnownBits to try to reduce NewBitWidth further and to check legality.
Please let me know whether this approach is more sensible.

https://github.com/llvm/llvm-project/pull/99250


More information about the llvm-commits mailing list