[llvm] [InstCombine] Pattern match minmax calls for unsigned saturation. (PR #99250)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 23 01:04:44 PDT 2024


================
@@ -1117,68 +1117,108 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
-/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
+/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
+/// value.
+Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
 
-  // We are looking for a tree of:
-  // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
-  // Where the min and max could be reversed
-  Instruction *MinMax2;
+  // 1. We are looking for a tree of signed saturation:
+  //    smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
+  //    Where the smin and smax could be reversed.
+  // 2. A tree of unsigned saturation:
+  //    smax(UINT_MIN, sub(zext(A), zext(B)))
+  //    Or umin(UINT_MAX, add(zext(A), zext(B))).
+  Instruction *MinMax2 = nullptr;
   BinaryOperator *AddSub;
-  const APInt *MinValue, *MaxValue;
-  if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
-    if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+  const APInt *MinValue = nullptr, *MaxValue = nullptr;
+  bool IsUnsignedSaturate = false;
+  // Pattern match for unsigned saturation.
+  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
+    // Bail out if AddSub could be negative.
+    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
       return nullptr;
-  } else if (match(&MinMax1,
-                   m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
-    if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+    IsUnsignedSaturate = true;
+  } else if (match(&MinMax1, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) {
+    if (!MinValue->isZero())
       return nullptr;
-  } else
-    return nullptr;
+    IsUnsignedSaturate = true;
+  } else {
+    // Pattern match for signed saturation.
+    if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
+      if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+        return nullptr;
+    } else if (match(&MinMax1,
+                     m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
+      if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+        return nullptr;
+    } else
+      return nullptr;
+  }
 
   // Check that the constants clamp a saturate, and that the new type would be
   // sensible to convert to.
-  if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+  if ((MaxValue && !(*MaxValue + 1).isPowerOf2()) ||
+      (!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
     return nullptr;
-  // In what bitwidth can this be treated as saturating arithmetics?
-  unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+
+  // Trying to decide the bitwidth for saturating arithmetics.
+  Value *Op0 = AddSub->getOperand(0);
+  Value *Op1 = AddSub->getOperand(1);
+  unsigned Op0MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op0, 0, AddSub);
+  unsigned Op1MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op1, 0, AddSub);
+  unsigned NewBitWidth = IsUnsignedSaturate
+                             ? std::max(Op0MaxBitWidth, Op1MaxBitWidth)
+                             : (*MaxValue + 1).logBase2() + 1;
----------------
goldsteinn wrote:

The known-bits computations (`computeKnownBits` / `ComputeMaxSignificantBits`) should be placed after the following checks:
```
  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
    return nullptr;

  // Also make sure that the inner min/max and the add/sub have one use.
```

Checks below.

https://github.com/llvm/llvm-project/pull/99250


More information about the llvm-commits mailing list