[llvm] [InstCombine] Pattern match minmax calls for unsigned saturation. (PR #99250)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 23 01:04:44 PDT 2024
================
@@ -1117,68 +1117,108 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
-/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
+/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
+/// value.
+Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
- // We are looking for a tree of:
- // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
- // Where the min and max could be reversed
- Instruction *MinMax2;
+ // 1. We are looking for a tree of signed saturation:
+ // smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
+ // Where the smin and smax could be reversed.
+ // 2. A tree of unsigned saturation:
+ // smax(UINT_MIN, sub(zext(A), zext(B)))
+ // Or umin(UINT_MAX, add(zext(A), zext(B))).
+ Instruction *MinMax2 = nullptr;
BinaryOperator *AddSub;
- const APInt *MinValue, *MaxValue;
- if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
- if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+ const APInt *MinValue = nullptr, *MaxValue = nullptr;
+ bool IsUnsignedSaturate = false;
+ // Pattern match for unsigned saturation.
+ if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
+ // Bail out if AddSub could be negative.
+ if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
return nullptr;
- } else if (match(&MinMax1,
- m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
- if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+ IsUnsignedSaturate = true;
+ } else if (match(&MinMax1, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) {
+ if (!MinValue->isZero())
return nullptr;
- } else
- return nullptr;
+ IsUnsignedSaturate = true;
+ } else {
+ // Pattern match for signed saturation.
+ if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
+ if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+ return nullptr;
+ } else if (match(&MinMax1,
+ m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
+ if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+ return nullptr;
+ } else
+ return nullptr;
+ }
// Check that the constants clamp a saturate, and that the new type would be
// sensible to convert to.
- if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+ if ((MaxValue && !(*MaxValue + 1).isPowerOf2()) ||
+ (!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
return nullptr;
- // In what bitwidth can this be treated as saturating arithmetics?
- unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+
+ // Trying to decide the bitwidth for saturating arithmetics.
+ Value *Op0 = AddSub->getOperand(0);
+ Value *Op1 = AddSub->getOperand(1);
+ unsigned Op0MaxBitWidth =
+ IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
+ : ComputeMaxSignificantBits(Op0, 0, AddSub);
+ unsigned Op1MaxBitWidth =
+ IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
+ : ComputeMaxSignificantBits(Op1, 0, AddSub);
+ unsigned NewBitWidth = IsUnsignedSaturate
+ ? std::max(Op0MaxBitWidth, Op1MaxBitWidth)
+ : (*MaxValue + 1).logBase2() + 1;
----------------
goldsteinn wrote:
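The `computeKnownBits` / `ComputeMaxSignificantBits` queries should be moved so that they come after these checks:
```
if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
return nullptr;
// Also make sure that the inner min/max and the add/sub have one use.
```
That is, after the checks that currently appear further down in this function.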
https://github.com/llvm/llvm-project/pull/99250
More information about the llvm-commits
mailing list