[llvm] [InstCombine] Add support for transforming manual signed mul overflows (PR #93370)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 25 07:26:37 PDT 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/93370

>From 226d62eb18e15e741d898b1c44d6f9700919507a Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Fri, 24 May 2024 23:54:06 -0400
Subject: [PATCH 1/2] Pre-commit tests (NFC)

---
 .../Transforms/InstCombine/overflow-mul.ll    | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 6b5a65c03ee10..3bc3c290ca8c9 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -343,3 +343,44 @@ define i32 @extra_and_use_mask_too_large(i32 %x, i32 %y) {
   %retval = zext i1 %overflow to i32
   ret i32 %retval
 }
+
+define i32 @smul(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[CONV3]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  ; Overflow check: (%mul + INT32_MIN) u< -2^32 iff %mul is outside the i32 range.
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %conv3 = zext i1 %2 to i32
+  ret i32 %conv3
+}
+
+define i32 @smul2(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul2(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[MUL]], 2147483647
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 4294967295
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[CONV3]]
+;
+  ; Range check of %mul against the i32 bounds: two compares joined by a select.
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %cmp = icmp sle i64 %mul, 2147483647
+  %cmp2 = icmp sgt i64 %mul, -2147483648
+  %1 = select i1 %cmp, i1 %cmp2, i1 false
+  %conv3 = zext i1 %1 to i32
+  ret i32 %conv3
+}

>From 9ed75dd0d7dc9d4aa409579920a8a5490d7c41c1 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Fri, 24 May 2024 23:50:47 -0400
Subject: [PATCH 2/2] [InstCombine] Add support for transforming manual checks
 for signed mul overflows

Alive2 Proof:
https://alive2.llvm.org/ce/z/m-kd7-
---
 .../InstCombine/InstCombineCompares.cpp       | 176 ++++++++++++++++++
 .../Transforms/InstCombine/overflow-mul.ll    |  19 +-
 2 files changed, 183 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 430f3e12fa5b8..c25380aa22494 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6170,6 +6170,172 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   return ExtractValueInst::Create(Call, 1);
 }
 
+/// Recognize and process idiom involving test for signed multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+///   I = icmp Pred (add (mul nsw (sext A), (sext B)), AddVal), OtherVal
+/// The function checks if this is a test for overflow and if so replaces
+/// the multiplication with a call to the 'smul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of the 'mul' instruction; it is operand 0 of the
+///               'add' that feeds \p I.  Must be of integer type.
+/// \param AddVal Constant operand of the 'add'.
+/// \param OtherVal Constant right-hand side of the compare.
+/// \returns Instruction which must replace the compare instruction, NULL if no
+///          replacement required.
+static Instruction *processSMulSExtIdiom(ICmpInst &I, Value *MulVal,
+                                         const APInt *AddVal,
+                                         const APInt *OtherVal,
+                                         InstCombinerImpl &IC) {
+  // Don't bother doing this transformation for pointers, don't do it for
+  // vectors.
+  if (!isa<IntegerType>(MulVal->getType()))
+    return nullptr;
+
+  auto *MulInstr = dyn_cast<Instruction>(MulVal);
+  if (!MulInstr)
+    return nullptr;
+  assert(MulInstr->getOpcode() == Instruction::Mul);
+
+  // The 'add' feeding I is part of the idiom: it is skipped as a use of MulVal
+  // below, and must be single-use so the wide mul and add actually go away.
+  auto *AddInst = dyn_cast<BinaryOperator>(I.getOperand(0));
+  if (!AddInst || AddInst->getOperand(0) != MulVal || !AddInst->hasOneUse())
+    return nullptr;
+
+  auto *LHS = cast<SExtInst>(MulInstr->getOperand(0)),
+       *RHS = cast<SExtInst>(MulInstr->getOperand(1));
+  assert(LHS->getOpcode() == Instruction::SExt);
+  assert(RHS->getOpcode() == Instruction::SExt);
+  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+  // Calculate type and width of the result produced by mul.with.overflow.
+  Type *TyA = A->getType(), *TyB = B->getType();
+  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+           WidthB = TyB->getPrimitiveSizeInBits();
+  unsigned MulWidth = std::max(WidthA, WidthB);
+  Type *MulType = WidthB > WidthA ? TyB : TyA;
+
+  // In order to replace the original mul with a narrower mul.with.overflow,
+  // all other uses must ignore upper bits of the product: the number of used
+  // low bits must not exceed the width of mul.with.overflow.
+  if (MulVal->hasNUsesOrMore(2))
+    for (User *U : MulVal->users()) {
+      if (U == AddInst)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        // Check if truncation ignores bits above MulWidth.
+        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+        if (TruncWidth > MulWidth)
+          return nullptr;
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        // Check if AND ignores bits above MulWidth.
+        if (BO->getOpcode() != Instruction::And)
+          return nullptr;
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+          const APInt &CVal = CI->getValue();
+          if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+            return nullptr;
+        } else {
+          // In this case we could have the operand of the binary operation
+          // being defined in another block, and performing the replacement
+          // could break the dominance relation.
+          return nullptr;
+        }
+      } else {
+        // Other uses prohibit this transformation.
+        return nullptr;
+      }
+    }
+
+  // Recognize patterns
+  bool IsInverse = false;
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_ULT: {
+    // Recognize pattern (compare is true iff the product overflows):
+    //   mulval = mul(sext A, sext B)
+    //   addval = add (mulval, SMin)       ; SMin = -2^(MulWidth-1)
+    //   cmp ult addval, SMin * 2          ; i.e. -2^MulWidth
+    APInt MinVal = APInt::getSignedMinValue(MulWidth);
+    MinVal = MinVal.sext(OtherVal->getBitWidth());
+    APInt MinMinVal = APInt::getSignedMinValue(MulWidth + 1);
+    MinMinVal = MinMinVal.sext(OtherVal->getBitWidth());
+    if (MinVal.eq(*AddVal) && MinMinVal.eq(*OtherVal))
+      break; // Recognized
+
+    // Recognize pattern (compare is true iff the product does not overflow):
+    //   mulval = mul(sext A, sext B)
+    //   addval = add (mulval, SMax + 1)   ; 2^(MulWidth-1)
+    //   cmp ult addval, UMax + 1          ; 2^MulWidth
+    APInt MaxVal = APInt::getSignedMaxValue(MulWidth);
+    MaxVal = MaxVal.zext(OtherVal->getBitWidth()) + 1;
+    APInt MaxMaxVal = APInt::getMaxValue(MulWidth);
+    MaxMaxVal = MaxMaxVal.zext(OtherVal->getBitWidth()) + 1;
+    if (MaxVal.eq(*AddVal) && MaxMaxVal.eq(*OtherVal)) {
+      IsInverse = true;
+      break; // Recognized
+    }
+    return nullptr;
+  }
+
+  default:
+    return nullptr;
+  }
+
+  InstCombiner::BuilderTy &Builder = IC.Builder;
+  Builder.SetInsertPoint(MulInstr);
+
+  // Replace: mul(sext A, sext B) --> mul.with.overflow(A, B)
+  Value *MulA = A, *MulB = B;
+  if (WidthA < MulWidth)
+    MulA = Builder.CreateSExt(A, MulType);
+  if (WidthB < MulWidth)
+    MulB = Builder.CreateSExt(B, MulType);
+  Function *F = Intrinsic::getDeclaration(
+      I.getModule(), Intrinsic::smul_with_overflow, MulType);
+  CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "smul");
+  IC.addToWorklist(MulInstr);
+
+  // If there are uses of mul result other than the 'add', we know that
+  // they are truncation or binary AND. Change them to use result of
+  // mul.with.overflow and adjust properly mask/size.
+  if (MulVal->hasNUsesOrMore(2)) {
+    Value *Mul = Builder.CreateExtractValue(Call, 0, "smul.value");
+    for (User *U : make_early_inc_range(MulVal->users())) {
+      if (U == AddInst)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+          IC.replaceInstUsesWith(*TI, Mul);
+        else
+          TI->setOperand(0, Mul);
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        assert(BO->getOpcode() == Instruction::And);
+        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
+        APInt ShortMask = CI->getValue().trunc(MulWidth);
+        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+        Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
+        IC.replaceInstUsesWith(*BO, Zext);
+      } else {
+        llvm_unreachable("Unexpected Binary operation");
+      }
+      IC.addToWorklist(cast<Instruction>(U));
+    }
+  }
+
+  // The original icmp gets replaced with the overflow value, maybe inverted
+  // depending on predicate.
+  if (IsInverse) {
+    Value *Res = Builder.CreateExtractValue(Call, 1);
+    return BinaryOperator::CreateNot(Res);
+  }
+
+  return ExtractValueInst::Create(Call, 1);
+}
+
 /// When performing a comparison against a constant, it is possible that not all
 /// the bits in the LHS are demanded. This helper method computes the mask that
 /// IS demanded.
@@ -7415,6 +7581,16 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
         return R;
     }
 
+    // (add (mul nsw (sext X), (sext Y)), C) u< C1 --> llvm.smul.with.overflow.
+    const APInt *C1;
+    if (match(Op0, m_Add(m_NSWMul(m_SExt(m_Value(X)), m_SExt(m_Value(Y))),
+                         m_APInt(C))) &&
+        match(Op1, m_APInt(C1))) {
+      if (Instruction *R = processSMulSExtIdiom(
+              I, cast<Instruction>(Op0)->getOperand(0), C, C1, *this))
+        return R;
+    }
+
     // Signbit test folds
     // Fold (X u>> BitWidth - 1 Pred ZExt(i1))  -->  X s< 0 Pred i1
     // Fold (X s>> BitWidth - 1 Pred SExt(i1))  -->  X s< 0 Pred i1
diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 3bc3c290ca8c9..582697b3d81e8 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -346,12 +346,9 @@ define i32 @extra_and_use_mask_too_large(i32 %x, i32 %y) {
 
 define i32 @smul(i32 %a, i32 %b) {
 ; CHECK-LABEL: @smul(
-; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[CONV3]]
 ;
   %conv = sext i32 %a to i64
@@ -365,11 +362,9 @@ define i32 @smul(i32 %a, i32 %b) {
 
 define i32 @smul2(i32 %a, i32 %b) {
 ; CHECK-LABEL: @smul2(
-; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[MUL]], 2147483647
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 4294967295
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], true
 ; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
 ; CHECK-NEXT:    ret i32 [[CONV3]]
 ;
@@ -377,7 +372,7 @@ define i32 @smul2(i32 %a, i32 %b) {
   %conv1 = sext i32 %b to i64
   %mul = mul nsw i64 %conv1, %conv
   %cmp = icmp sle i64 %mul, 2147483647
-  %cmp2 = icmp sgt i64 %mul, -2147483648
+  %cmp2 = icmp sge i64 %mul, -2147483648
   %1 = select i1 %cmp, i1 %cmp2, i1 false
   %conv3 = zext i1 %1 to i32
   ret i32 %conv3



More information about the llvm-commits mailing list