[llvm] [InstCombine] Add support for transforming manual signed mul overflows (PR #93370)
via llvm-commits
llvm-commits at lists.llvm.org
Sat May 25 07:26:37 PDT 2024
https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/93370
>From 226d62eb18e15e741d898b1c44d6f9700919507a Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Fri, 24 May 2024 23:54:06 -0400
Subject: [PATCH 1/2] Pre-commit tests (NFC)
---
.../Transforms/InstCombine/overflow-mul.ll | 39 +++++++++++++++++++
1 file changed, 39 insertions(+)
diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 6b5a65c03ee10..3bc3c290ca8c9 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -343,3 +343,42 @@ define i32 @extra_and_use_mask_too_large(i32 %x, i32 %y) {
%retval = zext i1 %overflow to i32
ret i32 %retval
}
+
+define i32 @smul(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul(
+; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT: [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT: [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT: ret i32 [[CONV3]]
+;
+ %conv = sext i32 %a to i64
+ %conv1 = sext i32 %b to i64
+ %mul = mul nsw i64 %conv1, %conv
+ %1 = add nsw i64 %mul, -2147483648
+ %2 = icmp ult i64 %1, -4294967296
+ %conv3 = zext i1 %2 to i32
+ ret i32 %conv3
+}
+
+define i32 @smul2(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul2(
+; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[MUL]], 2147483647
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 4294967295
+; CHECK-NEXT: [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT: ret i32 [[CONV3]]
+;
+ %conv = sext i32 %a to i64
+ %conv1 = sext i32 %b to i64
+ %mul = mul nsw i64 %conv1, %conv
+ %cmp = icmp sle i64 %mul, 2147483647
+ %cmp2 = icmp sgt i64 %mul, -2147483648
+ %1 = select i1 %cmp, i1 %cmp2, i1 false
+ %conv3 = zext i1 %1 to i32
+ ret i32 %conv3
+}
>From 9ed75dd0d7dc9d4aa409579920a8a5490d7c41c1 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Fri, 24 May 2024 23:50:47 -0400
Subject: [PATCH 2/2] [InstCombine] Add support for transforming manual checks
for signed mul overflows
Alive2 Proof:
https://alive2.llvm.org/ce/z/m-kd7-
---
.../InstCombine/InstCombineCompares.cpp | 176 ++++++++++++++++++
.../Transforms/InstCombine/overflow-mul.ll | 19 +-
2 files changed, 183 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 430f3e12fa5b8..c25380aa22494 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6170,6 +6170,172 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
return ExtractValueInst::Create(Call, 1);
}
+/// Recognize and process idiom involving test for multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+/// I = cmp u add (mul(sext A, sext B), V, W
+/// The function checks if this is a test for overflow and if so replaces
+/// multiplication with call to 'mul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
+/// the compare instruction. Must be of integer type.
+/// \param OtherVal The other argument of compare instruction.
+/// \returns Instruction which must replace the compare instruction, NULL if no
+/// replacement required.
+static Instruction *processSMulSExtIdiom(ICmpInst &I, Value *MulVal,
+ const APInt *AddVal,
+ const APInt *OtherVal,
+ InstCombinerImpl &IC) {
+ // Don't bother doing this transformation for pointers, don't do it for
+ // vectors.
+ if (!isa<IntegerType>(MulVal->getType()))
+ return nullptr;
+
+ auto *MulInstr = dyn_cast<Instruction>(MulVal);
+ if (!MulInstr)
+ return nullptr;
+ assert(MulInstr->getOpcode() == Instruction::Mul);
+
+ auto *LHS = cast<SExtInst>(MulInstr->getOperand(0)),
+ *RHS = cast<SExtInst>(MulInstr->getOperand(1));
+ assert(LHS->getOpcode() == Instruction::SExt);
+ assert(RHS->getOpcode() == Instruction::SExt);
+ Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+ // Calculate type and width of the result produced by mul.with.overflow.
+ Type *TyA = A->getType(), *TyB = B->getType();
+ unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+ WidthB = TyB->getPrimitiveSizeInBits();
+ unsigned MulWidth;
+ Type *MulType;
+ if (WidthB > WidthA) {
+ MulWidth = WidthB;
+ MulType = TyB;
+ } else {
+ MulWidth = WidthA;
+ MulType = TyA;
+ }
+
+ // In order to replace the original mul with a narrower mul.with.overflow,
+ // all uses must ignore upper bits of the product. The number of used low
+ // bits must be not greater than the width of mul.with.overflow.
+ if (MulVal->hasNUsesOrMore(2))
+ for (User *U : MulVal->users()) {
+ if (U == &I)
+ continue;
+ if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+ // Check if truncation ignores bits above MulWidth.
+ unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+ if (TruncWidth > MulWidth)
+ return nullptr;
+ } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+ // Check if AND ignores bits above MulWidth.
+ if (BO->getOpcode() != Instruction::And)
+ return nullptr;
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+ const APInt &CVal = CI->getValue();
+ if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+ return nullptr;
+ } else {
+ // In this case we could have the operand of the binary operation
+ // being defined in another block, and performing the replacement
+ // could break the dominance relation.
+ return nullptr;
+ }
+ } else {
+ // Other uses prohibit this transformation.
+ return nullptr;
+ }
+ }
+
+ // Recognize patterns
+ bool IsInverse = false;
+ switch (I.getPredicate()) {
+ case ICmpInst::ICMP_ULT: {
+ // Recognize pattern:
+ // mulval = mul(sext A, sext B)
+ // addval = add (mulval, min)
+ // cmp ult addval, -min * 2 + 1
+ APInt MinVal = APInt::getSignedMinValue(MulWidth);
+ MinVal = MinVal.sext(OtherVal->getBitWidth());
+ APInt MinMinVal = APInt::getSignedMinValue(MulWidth + 1);
+ MinMinVal = MinMinVal.sext(OtherVal->getBitWidth());
+ if (MinVal.eq(*AddVal) && MinMinVal.eq(*OtherVal))
+ break; // Recognized
+
+ // Recognize pattern:
+ // mulval = mul(sext A, sext B)
+ // addval = add (mulval, signedMax)
+ // cmp ult addval, unsignedMax
+ APInt MaxVal = APInt::getSignedMaxValue(MulWidth);
+ MaxVal = MaxVal.zext(OtherVal->getBitWidth()) + 1;
+ APInt MaxMaxVal = APInt::getMaxValue(MulWidth);
+ MaxMaxVal = MaxMaxVal.zext(OtherVal->getBitWidth()) + 1;
+ if (MaxVal.eq(*AddVal) && MaxMaxVal.eq(*OtherVal)) {
+ IsInverse = true;
+ break; // Recognized
+ }
+ return nullptr;
+ }
+
+ default:
+ return nullptr;
+ }
+
+ InstCombiner::BuilderTy &Builder = IC.Builder;
+ Builder.SetInsertPoint(MulInstr);
+
+ // Replace: mul(sext A, sext B) --> mul.with.overflow(A, B)
+ Value *MulA = A, *MulB = B;
+ if (WidthA < MulWidth)
+ MulA = Builder.CreateSExt(A, MulType);
+ if (WidthB < MulWidth)
+ MulB = Builder.CreateSExt(B, MulType);
+ Function *F = Intrinsic::getDeclaration(
+ I.getModule(), Intrinsic::smul_with_overflow, MulType);
+ CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "smul");
+ IC.addToWorklist(MulInstr);
+
+ // If there are uses of mul result other than the comparison, we know that
+ // they are truncation or binary AND. Change them to use result of
+ // mul.with.overflow and adjust properly mask/size.
+ if (MulVal->hasNUsesOrMore(2)) {
+ Value *Mul = Builder.CreateExtractValue(Call, 0, "smul.value");
+ for (User *U : make_early_inc_range(MulVal->users())) {
+ if (U == &I)
+ continue;
+ if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+ if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+ IC.replaceInstUsesWith(*TI, Mul);
+ else
+ TI->setOperand(0, Mul);
+ } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+ assert(BO->getOpcode() == Instruction::And);
+ // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+ ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
+ APInt ShortMask = CI->getValue().trunc(MulWidth);
+ Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+ Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
+ IC.replaceInstUsesWith(*BO, Zext);
+ } else {
+ llvm_unreachable("Unexpected Binary operation");
+ }
+ IC.addToWorklist(cast<Instruction>(U));
+ }
+ }
+
+ // The original icmp gets replaced with the overflow value, maybe inverted
+ // depending on predicate.
+ if (IsInverse) {
+ Value *Res = Builder.CreateExtractValue(Call, 1);
+ return BinaryOperator::CreateNot(Res);
+ }
+
+ return ExtractValueInst::Create(Call, 1);
+}
+
/// When performing a comparison against a constant, it is possible that not all
/// the bits in the LHS are demanded. This helper method computes the mask that
/// IS demanded.
@@ -7415,6 +7581,16 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
return R;
}
+ // (sext X) * (sext Y) --> llvm.smul.with.overflow.
+ const APInt *C1;
+ if (match(Op0, m_Add(m_NSWMul(m_SExt(m_Value(X)), m_SExt(m_Value(Y))),
+ m_APInt(C))) &&
+ match(Op1, m_APInt(C1))) {
+ if (Instruction *R = processSMulSExtIdiom(
+ I, cast<Instruction>(Op0)->getOperand(0), C, C1, *this))
+ return R;
+ }
+
// Signbit test folds
// Fold (X u>> BitWidth - 1 Pred ZExt(i1)) --> X s< 0 Pred i1
// Fold (X s>> BitWidth - 1 Pred SExt(i1)) --> X s< 0 Pred i1
diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 3bc3c290ca8c9..582697b3d81e8 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -346,12 +346,9 @@ define i32 @extra_and_use_mask_too_large(i32 %x, i32 %y) {
define i32 @smul(i32 %a, i32 %b) {
; CHECK-LABEL: @smul(
-; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT: [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
-; CHECK-NEXT: [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT: [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT: [[CONV3:%.*]] = zext i1 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[CONV3]]
;
%conv = sext i32 %a to i64
@@ -365,11 +362,9 @@ define i32 @smul(i32 %a, i32 %b) {
define i32 @smul2(i32 %a, i32 %b) {
; CHECK-LABEL: @smul2(
-; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[MUL]], 2147483647
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 4294967295
+; CHECK-NEXT: [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = xor i1 [[TMP1]], true
; CHECK-NEXT: [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
; CHECK-NEXT: ret i32 [[CONV3]]
;
@@ -377,7 +372,7 @@ define i32 @smul2(i32 %a, i32 %b) {
%conv1 = sext i32 %b to i64
%mul = mul nsw i64 %conv1, %conv
%cmp = icmp sle i64 %mul, 2147483647
- %cmp2 = icmp sgt i64 %mul, -2147483648
+ %cmp2 = icmp sge i64 %mul, -2147483648
%1 = select i1 %cmp, i1 %cmp2, i1 false
%conv3 = zext i1 %1 to i32
ret i32 %conv3
More information about the llvm-commits
mailing list