[llvm] r206137 - Recognize test for overflow in integer multiplication.

Benjamin Kramer benny.kra at gmail.com
Sun Apr 13 12:06:20 PDT 2014


On 13.04.2014, at 20:23, Serge Pavlov <sepavloff at gmail.com> wrote:

> Author: sepavloff
> Date: Sun Apr 13 13:23:41 2014
> New Revision: 206137
> 
> URL: http://llvm.org/viewvc/llvm-project?rev=206137&view=rev
> Log:
> Recognize test for overflow in integer multiplication.
> 
> If a multiplication involves zero-extended arguments and its result is
> compared as in the patterns:
> 
>    %mul32 = trunc i64 %mul64 to i32
>    %zext = zext i32 %mul32 to i64
>    %overflow = icmp ne i64 %mul64, %zext
> or
>    %overflow = icmp ugt i64 %mul64, 4294967295
> 
> then the multiplication may be replaced by a call to umul.with.overflow.
> This change fixes PR4917 and PR4918.
> 
> Differential Revision: http://llvm-reviews.chandlerc.com/D2814
> 
> Added:
>    llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll
> Modified:
>    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
> 
> Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=206137&r1=206136&r2=206137&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
> +++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Sun Apr 13 13:23:41 2014
> @@ -2008,6 +2008,236 @@ static Instruction *ProcessUAddIdiom(Ins
>   return ExtractValueInst::Create(Call, 1, "uadd.overflow");
> }
> 
> +/// \brief Recognize and process an idiom that tests a multiplication for
> +/// overflow.
> +///
> +/// The caller has matched a pattern of the form:
> +///   I = cmp u (mul(zext A, zext B)), V
> +/// If this is a test for overflow, the multiplication is replaced by a call
> +/// to the 'umul.with.overflow' intrinsic.
> +/// \param I Compare instruction.
> +/// \param MulVal Result of the 'mul' instruction; one of the operands of \p I.
> +///               Must be of integer type.
> +/// \param OtherVal The other operand of the compare instruction.
> +/// \param IC The combiner whose builder and worklist receive new instructions.
> +/// \returns Instruction which must replace the compare instruction, or null
> +///          if no replacement is required.
> +static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
> +                                         Value *OtherVal, InstCombiner &IC) {
> +  assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
> +  assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
> +  assert(isa<IntegerType>(MulVal->getType()));
> +  Instruction *MulInstr = cast<Instruction>(MulVal);
> +  assert(MulInstr->getOpcode() == Instruction::Mul);
> +
> +  Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
> +              *RHS = cast<Instruction>(MulInstr->getOperand(1));
> +  assert(LHS->getOpcode() == Instruction::ZExt);
> +  assert(RHS->getOpcode() == Instruction::ZExt);
> +  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
> +
> +  // Calculate type and width of the result produced by mul.with.overflow.
> +  Type *TyA = A->getType(), *TyB = B->getType();
> +  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
> +           WidthB = TyB->getPrimitiveSizeInBits();
> +  unsigned MulWidth;
> +  Type *MulType;
> +  if (WidthB > WidthA) {
> +    MulWidth = WidthB;
> +    MulType = TyB;
> +  } else {
> +    MulWidth = WidthA;
> +    MulType = TyA;
> +  }
> +
> +  // In order to replace the original mul with a narrower mul.with.overflow,
> +  // all uses must ignore upper bits of the product.  The number of used low
> +  // bits must be not greater than the width of mul.with.overflow.
> +  if (MulVal->hasNUsesOrMore(2))
> +    for (User *U : MulVal->users()) {
> +      if (U == &I)
> +        continue;
> +      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
> +        // Check if truncation ignores bits above MulWidth.
> +        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
> +        if (TruncWidth > MulWidth)
> +          return 0;
> +      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
> +        // Check if AND ignores bits above MulWidth.
> +        if (BO->getOpcode() != Instruction::And)
> +          return 0;
> +        // The mask must be a constant: the rewrite below narrows it with
> +        // cast<ConstantInt>, which would fail on any other operand.
> +        ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1));
> +        if (!CI || CI->getValue().getActiveBits() > MulWidth)
> +          return 0;
> +      } else {
> +        // Other uses prohibit this transformation.
> +        return 0;
> +      }
> +    }
> +
> +  // Recognize patterns
> +  switch (I.getPredicate()) {
> +  case ICmpInst::ICMP_EQ:
> +  case ICmpInst::ICMP_NE:
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp eq/neq mulval, zext trunc mulval
> +    if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
> +      if (Zext->hasOneUse()) {
> +        Value *ZextArg = Zext->getOperand(0);
> +        if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
> +          if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
> +            break; //Recognized
> +      }
> +
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
> +    ConstantInt *CI;
> +    Value *ValToMask;
> +    if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
> +      if (ValToMask != MulVal)
> +        return 0;
> +      const APInt &CVal = CI->getValue() + 1;
> +      if (CVal.isPowerOf2()) {
> +        unsigned MaskWidth = CVal.logBase2();
> +        if (MaskWidth == MulWidth)
> +          break; // Recognized
> +      }
> +    }
> +    return 0;
> +
> +  case ICmpInst::ICMP_UGT:
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp ugt mulval, max
> +    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
> +      APInt MaxVal = APInt::getMaxValue(MulWidth);
> +      MaxVal = MaxVal.zext(CI->getBitWidth());
> +      if (MaxVal.eq(CI->getValue()))
> +        break; // Recognized
> +    }
> +    return 0;
> +
> +  case ICmpInst::ICMP_UGE:
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp uge mulval, max+1
> +    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
> +      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
> +      if (MaxVal.eq(CI->getValue()))
> +        break; // Recognized
> +    }
> +    return 0;
> +
> +  case ICmpInst::ICMP_ULE:
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp ule mulval, max
> +    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
> +      APInt MaxVal = APInt::getMaxValue(MulWidth);
> +      MaxVal = MaxVal.zext(CI->getBitWidth());
> +      if (MaxVal.eq(CI->getValue()))
> +        break; // Recognized
> +    }
> +    return 0;
> +
> +  case ICmpInst::ICMP_ULT:
> +    // Recognize pattern:
> +    //   mulval = mul(zext A, zext B)
> +    //   cmp ult mulval, max + 1
> +    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
> +      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);

There's still a case using 64 bit arithmetic here.

> +      if (MaxVal.eq(CI->getValue()))
> +        break; // Recognized
> +    }
> +    return 0;
> +
> +  default:
> +    return 0;
> +  }
> +
> +  InstCombiner::BuilderTy *Builder = IC.Builder;
> +  Builder->SetInsertPoint(MulInstr);
> +  Module *M = I.getParent()->getParent()->getParent();
> +
> +  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
> +  Value *MulA = A, *MulB = B;
> +  if (WidthA < MulWidth)
> +    MulA = Builder->CreateZExt(A, MulType);
> +  if (WidthB < MulWidth)
> +    MulB = Builder->CreateZExt(B, MulType);
> +  Value *F =
> +      Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
> +  CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
> +  IC.Worklist.Add(MulInstr);
> +
> +  // If there are uses of mul result other than the comparison, we know that
> +  // they are truncation or binary AND. Change them to use result of
> +  // mul.with.overflow and adjust the mask/size accordingly.

Typo: ajust.

- Ben

> +  if (MulVal->hasNUsesOrMore(2)) {
> +    Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value");
> +    for (User *U : MulVal->users()) {
> +      if (U == &I || U == OtherVal)
> +        continue;
> +      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
> +        // Rebuild the truncation on the narrow product rather than calling
> +        // TI->setOperand(), which would unlink a use of MulVal mid-iteration.
> +        // CreateTrunc returns Mul itself when the widths already match.
> +        IC.ReplaceInstUsesWith(*TI, Builder->CreateTrunc(Mul, TI->getType()));
> +      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
> +        assert(BO->getOpcode() == Instruction::And);
> +        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
> +        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
> +        APInt ShortMask = CI->getValue().trunc(MulWidth);
> +        Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask);
> +        Instruction *Zext =
> +            cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType()));
> +        IC.Worklist.Add(Zext);
> +        IC.ReplaceInstUsesWith(*BO, Zext);
> +      } else {
> +        llvm_unreachable("Unexpected Binary operation");
> +      }
> +      IC.Worklist.Add(cast<Instruction>(U));
> +    }
> +  }
> +  if (isa<Instruction>(OtherVal))
> +    IC.Worklist.Add(cast<Instruction>(OtherVal));
> +
> +  // The original icmp gets replaced with the overflow value, maybe inverted
> +  // depending on predicate.
> +  bool Inverse = false;
> +  switch (I.getPredicate()) {
> +  case ICmpInst::ICMP_NE:
> +    break;
> +  case ICmpInst::ICMP_EQ:
> +    Inverse = true;
> +    break;
> +  case ICmpInst::ICMP_UGT:
> +  case ICmpInst::ICMP_UGE:
> +    if (I.getOperand(0) == MulVal)
> +      break;
> +    Inverse = true;
> +    break;
> +  case ICmpInst::ICMP_ULT:
> +  case ICmpInst::ICMP_ULE:
> +    if (I.getOperand(1) == MulVal)
> +      break;
> +    Inverse = true;
> +    break;
> +  default:
> +    llvm_unreachable("Unexpected predicate");
> +  }
> +  if (Inverse) {
> +    Value *Res = Builder->CreateExtractValue(Call, 1);
> +    return BinaryOperator::CreateNot(Res);
> +  }
> +
> +  return ExtractValueInst::Create(Call, 1);
> +}
> +
> // DemandedBitsLHSMask - When performing a comparison against a constant,
> // it is possible that not all the bits in the LHS are demanded.  This helper
> // method computes the mask that IS demanded.
> @@ -2877,6 +3107,16 @@ Instruction *InstCombiner::visitICmpInst
>         (Op0 == A || Op0 == B))
>       if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
>         return R;
> +
> +    // (zext a) * (zext b)  --> llvm.umul.with.overflow.
> +    if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
> +      if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this))
> +        return R;
> +    }
> +    if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
> +      if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this))
> +        return R;
> +    }
>   }
> 
>   if (I.isEquality()) {
> 
> Added: llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll?rev=206137&view=auto
> ==============================================================================
> --- llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll (added)
> +++ llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll Sun Apr 13 13:23:41 2014
> @@ -0,0 +1,179 @@
> +; RUN: opt -S -instcombine < %s | FileCheck %s
> +
> +; return mul(zext x, zext y) > MAX
> +define i32 @pr4917_1(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_1(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +; CHECK-NOT: zext i32
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %overflow = icmp ugt i64 %mul64, 4294967295
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> +; return mul(zext x, zext y) >= MAX+1
> +define i32 @pr4917_1a(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_1a(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +; CHECK-NOT: zext i32
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %overflow = icmp uge i64 %mul64, 4294967296
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> +; mul(zext x, zext y) > MAX
> +; mul(x, y) is used
> +define i32 @pr4917_2(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_2(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +; CHECK-NOT: zext i32
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %overflow = icmp ugt i64 %mul64, 4294967295
> +; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
> +  %mul32 = trunc i64 %mul64 to i32
> +; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = select i1 %overflow, i32 %mul32, i32 111
> +; CHECK: select i1 [[OVFL]], i32 [[VAL]]
> +  ret i32 %retval
> +}
> +
> +; return mul(zext x, zext y) > MAX
> +; mul is used in non-truncate
> +define i64 @pr4917_3(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_3(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +  %mul64 = mul i64 %l, %r
> +; CHECK-NOT: umul.with.overflow.i32
> +  %overflow = icmp ugt i64 %mul64, 4294967295
> +  %retval = select i1 %overflow, i64 %mul64, i64 111
> +  ret i64 %retval
> +}
> +
> +; return mul(zext x, zext y) <= MAX
> +define i32 @pr4917_4(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_4(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +; CHECK-NOT: zext i32
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %overflow = icmp ule i64 %mul64, 4294967295
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +; CHECK: xor
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> +; return mul(zext x, zext y) < MAX+1
> +define i32 @pr4917_4a(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4917_4a(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +; CHECK-NOT: zext i32
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %overflow = icmp ult i64 %mul64, 4294967296
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +; CHECK: xor
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> +; operands of mul are of different size
> +define i32 @pr4917_5(i32 %x, i8 %y) nounwind {
> +; CHECK-LABEL: @pr4917_5(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i8 %y to i64
> +; CHECK: [[Y:%.*]] = zext i8 %y to i32
> +  %mul64 = mul i64 %l, %r
> +  %overflow = icmp ugt i64 %mul64, 4294967295
> +  %mul32 = trunc i64 %mul64 to i32
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 [[Y]])
> +; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
> +; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = select i1 %overflow, i32 %mul32, i32 111
> +; CHECK: select i1 [[OVFL]], i32 [[VAL]]
> +  ret i32 %retval
> +}
> +
> +; mul(zext x, zext y) > MAX, and the product is also masked by a
> +; non-constant value: the mask cannot be narrowed, so nothing is rewritten.
> +define i64 @pr4917_6(i32 %x, i32 %y, i64 %m) nounwind {
> +; CHECK-LABEL: @pr4917_6(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +  %mul64 = mul i64 %l, %r
> +; CHECK-NOT: umul.with.overflow.i32
> +  %overflow = icmp ugt i64 %mul64, 4294967295
> +  %masked = and i64 %mul64, %m
> +  %retval = select i1 %overflow, i64 %masked, i64 111
> +  ret i64 %retval
> +}
> +
> +; mul(zext x, zext y) != zext trunc mul
> +define i32 @pr4918_1(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4918_1(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %part32 = trunc i64 %mul64 to i32
> +  %part64 = zext i32 %part32 to i64
> +  %overflow = icmp ne i64 %mul64, %part64
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> +; mul(zext x, zext y) == zext trunc mul
> +define i32 @pr4918_2(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4918_2(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %part32 = trunc i64 %mul64 to i32
> +  %part64 = zext i32 %part32 to i64
> +  %overflow = icmp eq i64 %mul64, %part64
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = zext i1 %overflow to i32
> +; CHECK: xor
> +  ret i32 %retval
> +}
> +
> +; zext trunc mul != mul(zext x, zext y)
> +define i32 @pr4918_3(i32 %x, i32 %y) nounwind {
> +; CHECK-LABEL: @pr4918_3(
> +entry:
> +  %l = zext i32 %x to i64
> +  %r = zext i32 %y to i64
> +  %mul64 = mul i64 %l, %r
> +; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
> +  %part32 = trunc i64 %mul64 to i32
> +  %part64 = zext i32 %part32 to i64
> +  %overflow = icmp ne i64 %part64, %mul64
> +; CHECK: extractvalue { i32, i1 } [[MUL]], 1
> +  %retval = zext i1 %overflow to i32
> +  ret i32 %retval
> +}
> +
> 
> 
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at cs.uiuc.edu
> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits





More information about the llvm-commits mailing list