[llvm] r206137 - Recognize test for overflow in integer multiplication.
Serge Pavlov
sepavloff at gmail.com
Sun Apr 13 11:23:43 PDT 2014
Author: sepavloff
Date: Sun Apr 13 13:23:41 2014
New Revision: 206137
URL: http://llvm.org/viewvc/llvm-project?rev=206137&view=rev
Log:
Recognize test for overflow in integer multiplication.
If multiplication involves zero-extended arguments and the result is
compared as in the patterns:
    %mul32 = trunc i64 %mul64 to i32
    %zext = zext i32 %mul32 to i64
    %overflow = icmp ne i64 %mul64, %zext
or
    %overflow = icmp ugt i64 %mul64, 0xffffffff
then the multiplication may be replaced by a call to umul.with.overflow.
This change fixes PR4917 and PR4918.
Differential Revision: http://llvm-reviews.chandlerc.com/D2814
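For context, the idiom being recognized typically comes from C/C++ source like the following sketch (the function name is illustrative and not part of this change): Clang zero-extends both operands, multiplies in the wider type, and compares the product against UINT32_MAX, which is exactly the icmp ugt form shown above.

#include <cstdint>

// Widen, multiply, and compare the product against UINT32_MAX; with this
// change InstCombine can turn the mul/icmp pair into
// llvm.umul.with.overflow.i32.
static bool mul_overflows_u32(uint32_t x, uint32_t y) {
  uint64_t wide = static_cast<uint64_t>(x) * static_cast<uint64_t>(y);
  return wide > 0xffffffffULL; // true iff x * y does not fit in 32 bits
}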
Added:
llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll
Modified:
llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=206137&r1=206136&r2=206137&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Sun Apr 13 13:23:41 2014
@@ -2008,6 +2008,236 @@ static Instruction *ProcessUAddIdiom(Ins
return ExtractValueInst::Create(Call, 1, "uadd.overflow");
}
+/// \brief Recognize and process an idiom involving a test for multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+///   I = cmp u (mul(zext A, zext B), V)
+/// The function checks if this is a test for overflow and, if so, replaces
+/// the multiplication with a call to the 'umul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of the 'mul' instruction. It is one of the arguments
+/// of the compare instruction. Must be of integer type.
+/// \param OtherVal The other argument of the compare instruction.
+/// \returns Instruction which must replace the compare instruction, or NULL
+/// if no replacement is required.
+static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
+ Value *OtherVal, InstCombiner &IC) {
+ assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
+ assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
+ assert(isa<IntegerType>(MulVal->getType()));
+ Instruction *MulInstr = cast<Instruction>(MulVal);
+ assert(MulInstr->getOpcode() == Instruction::Mul);
+
+ Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
+ *RHS = cast<Instruction>(MulInstr->getOperand(1));
+ assert(LHS->getOpcode() == Instruction::ZExt);
+ assert(RHS->getOpcode() == Instruction::ZExt);
+ Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+ // Calculate type and width of the result produced by mul.with.overflow.
+ Type *TyA = A->getType(), *TyB = B->getType();
+ unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+ WidthB = TyB->getPrimitiveSizeInBits();
+ unsigned MulWidth;
+ Type *MulType;
+ if (WidthB > WidthA) {
+ MulWidth = WidthB;
+ MulType = TyB;
+ } else {
+ MulWidth = WidthA;
+ MulType = TyA;
+ }
+
+ // In order to replace the original mul with a narrower mul.with.overflow,
+ // all uses must ignore the upper bits of the product. The number of used
+ // low bits must not be greater than the width of mul.with.overflow.
+ if (MulVal->hasNUsesOrMore(2))
+ for (User *U : MulVal->users()) {
+ if (U == &I)
+ continue;
+ if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+ // Check if truncation ignores bits above MulWidth.
+ unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+ if (TruncWidth > MulWidth)
+ return 0;
+ } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+ // Check if AND ignores bits above MulWidth.
+ if (BO->getOpcode() != Instruction::And)
+ return 0;
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+ const APInt &CVal = CI->getValue();
+ if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
+ return 0;
+ }
+ } else {
+ // Other uses prohibit this transformation.
+ return 0;
+ }
+ }
+
+ // Recognize patterns
+ switch (I.getPredicate()) {
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp eq/ne mulval, zext trunc mulval
+ if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
+ if (Zext->hasOneUse()) {
+ Value *ZextArg = Zext->getOperand(0);
+ if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
+ if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
+ break; // Recognized
+ }
+
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp eq/ne mulval, and(mulval, mask),
+ // where mask selects the low MulWidth bits.
+ ConstantInt *CI;
+ Value *ValToMask;
+ if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
+ if (ValToMask != MulVal)
+ return 0;
+ const APInt &CVal = CI->getValue() + 1;
+ if (CVal.isPowerOf2()) {
+ unsigned MaskWidth = CVal.logBase2();
+ if (MaskWidth == MulWidth)
+ break; // Recognized
+ }
+ }
+ return 0;
+
+ case ICmpInst::ICMP_UGT:
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp ugt mulval, max
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+ APInt MaxVal = APInt::getMaxValue(MulWidth);
+ MaxVal = MaxVal.zext(CI->getBitWidth());
+ if (MaxVal.eq(CI->getValue()))
+ break; // Recognized
+ }
+ return 0;
+
+ case ICmpInst::ICMP_UGE:
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp uge mulval, max+1
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+ APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
+ if (MaxVal.eq(CI->getValue()))
+ break; // Recognized
+ }
+ return 0;
+
+ case ICmpInst::ICMP_ULE:
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp ule mulval, max
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+ APInt MaxVal = APInt::getMaxValue(MulWidth);
+ MaxVal = MaxVal.zext(CI->getBitWidth());
+ if (MaxVal.eq(CI->getValue()))
+ break; // Recognized
+ }
+ return 0;
+
+ case ICmpInst::ICMP_ULT:
+ // Recognize pattern:
+ // mulval = mul(zext A, zext B)
+ // cmp ult mulval, max + 1
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+ APInt MaxVal(CI->getBitWidth(), 1ULL << MulWidth);
+ if (MaxVal.eq(CI->getValue()))
+ break; // Recognized
+ }
+ return 0;
+
+ default:
+ return 0;
+ }
+
+ InstCombiner::BuilderTy *Builder = IC.Builder;
+ Builder->SetInsertPoint(MulInstr);
+ Module *M = I.getParent()->getParent()->getParent();
+
+ // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
+ Value *MulA = A, *MulB = B;
+ if (WidthA < MulWidth)
+ MulA = Builder->CreateZExt(A, MulType);
+ if (WidthB < MulWidth)
+ MulB = Builder->CreateZExt(B, MulType);
+ Value *F =
+ Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
+ CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
+ IC.Worklist.Add(MulInstr);
+
+ // If there are uses of the mul result other than the comparison, we know
+ // that they are truncations or binary ANDs. Change them to use the result
+ // of mul.with.overflow and adjust the mask/size accordingly.
+ if (MulVal->hasNUsesOrMore(2)) {
+ Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value");
+ for (User *U : MulVal->users()) {
+ if (U == &I || U == OtherVal)
+ continue;
+ if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+ if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+ IC.ReplaceInstUsesWith(*TI, Mul);
+ else
+ TI->setOperand(0, Mul);
+ } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+ assert(BO->getOpcode() == Instruction::And);
+ // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+ ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
+ APInt ShortMask = CI->getValue().trunc(MulWidth);
+ Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask);
+ Instruction *Zext =
+ cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType()));
+ IC.Worklist.Add(Zext);
+ IC.ReplaceInstUsesWith(*BO, Zext);
+ } else {
+ llvm_unreachable("Unexpected Binary operation");
+ }
+ IC.Worklist.Add(cast<Instruction>(U));
+ }
+ }
+ if (isa<Instruction>(OtherVal))
+ IC.Worklist.Add(cast<Instruction>(OtherVal));
+
+ // The original icmp gets replaced with the overflow value, possibly
+ // inverted depending on the predicate.
+ bool Inverse = false;
+ switch (I.getPredicate()) {
+ case ICmpInst::ICMP_NE:
+ break;
+ case ICmpInst::ICMP_EQ:
+ Inverse = true;
+ break;
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ if (I.getOperand(0) == MulVal)
+ break;
+ Inverse = true;
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ if (I.getOperand(1) == MulVal)
+ break;
+ Inverse = true;
+ break;
+ default:
+ llvm_unreachable("Unexpected predicate");
+ }
+ if (Inverse) {
+ Value *Res = Builder->CreateExtractValue(Call, 1);
+ return BinaryOperator::CreateNot(Res);
+ }
+
+ return ExtractValueInst::Create(Call, 1);
+}
+
// DemandedBitsLHSMask - When performing a comparison against a constant,
// it is possible that not all the bits in the LHS are demanded. This helper
// method computes the mask that IS demanded.
@@ -2877,6 +3107,16 @@ Instruction *InstCombiner::visitICmpInst
(Op0 == A || Op0 == B))
if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
return R;
+
+ // (zext a) * (zext b) --> llvm.umul.with.overflow.
+ if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+ if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this))
+ return R;
+ }
+ if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+ if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this))
+ return R;
+ }
}
if (I.isEquality()) {
Added: llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll?rev=206137&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll (added)
+++ llvm/trunk/test/Transforms/InstCombine/overflow-mul.ll Sun Apr 13 13:23:41 2014
@@ -0,0 +1,164 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+; return mul(zext x, zext y) > MAX
+define i32 @pr4917_1(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_1(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %overflow = icmp ugt i64 %mul64, 4294967295
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
+; return mul(zext x, zext y) >= MAX+1
+define i32 @pr4917_1a(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_1a(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %overflow = icmp uge i64 %mul64, 4294967296
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
+; mul(zext x, zext y) > MAX
+; mul(x, y) is used
+define i32 @pr4917_2(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_2(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %overflow = icmp ugt i64 %mul64, 4294967295
+; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
+ %mul32 = trunc i64 %mul64 to i32
+; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
+ %retval = select i1 %overflow, i32 %mul32, i32 111
+; CHECK: select i1 [[OVFL]], i32 [[VAL]]
+ ret i32 %retval
+}
+
+; return mul(zext x, zext y) > MAX
+; the mul result is also used by a non-truncate instruction
+define i64 @pr4917_3(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_3(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+ %mul64 = mul i64 %l, %r
+; CHECK-NOT: umul.with.overflow.i32
+ %overflow = icmp ugt i64 %mul64, 4294967295
+ %retval = select i1 %overflow, i64 %mul64, i64 111
+ ret i64 %retval
+}
+
+; return mul(zext x, zext y) <= MAX
+define i32 @pr4917_4(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_4(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %overflow = icmp ule i64 %mul64, 4294967295
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+; CHECK: xor
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
+; return mul(zext x, zext y) < MAX+1
+define i32 @pr4917_4a(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_4a(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %overflow = icmp ult i64 %mul64, 4294967296
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+; CHECK: xor
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
+; operands of mul are of different sizes
+define i32 @pr4917_5(i32 %x, i8 %y) nounwind {
+; CHECK-LABEL: @pr4917_5(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i8 %y to i64
+; CHECK: [[Y:%.*]] = zext i8 %y to i32
+ %mul64 = mul i64 %l, %r
+ %overflow = icmp ugt i64 %mul64, 4294967295
+ %mul32 = trunc i64 %mul64 to i32
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 [[Y]])
+; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
+; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
+ %retval = select i1 %overflow, i32 %mul32, i32 111
+; CHECK: select i1 [[OVFL]], i32 [[VAL]]
+ ret i32 %retval
+}
+
+; mul(zext x, zext y) != zext trunc mul
+define i32 @pr4918_1(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_1(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %part32 = trunc i64 %mul64 to i32
+ %part64 = zext i32 %part32 to i64
+ %overflow = icmp ne i64 %mul64, %part64
+; CHECK: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL:%.*]], 1
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
+; mul(zext x, zext y) == zext trunc mul
+define i32 @pr4918_2(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_2(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %part32 = trunc i64 %mul64 to i32
+ %part64 = zext i32 %part32 to i64
+ %overflow = icmp eq i64 %mul64, %part64
+; CHECK: extractvalue { i32, i1 } [[MUL]]
+ %retval = zext i1 %overflow to i32
+; CHECK: xor
+ ret i32 %retval
+}
+
+; zext trunc mul != mul(zext x, zext y)
+define i32 @pr4918_3(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_3(
+entry:
+ %l = zext i32 %x to i64
+ %r = zext i32 %y to i64
+ %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+ %part32 = trunc i64 %mul64 to i32
+ %part64 = zext i32 %part32 to i64
+ %overflow = icmp ne i64 %part64, %mul64
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+ %retval = zext i1 %overflow to i32
+ ret i32 %retval
+}
+
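As a usage note (not part of this commit): the intrinsic produced by this combine is the same one Clang emits for the __builtin_umul_overflow builtin, so after this change both spellings of the check should optimize to equivalent IR. A minimal sketch, assuming a 32-bit unsigned int:

// Overflow check written with the Clang/GCC builtin; Clang lowers this
// directly to llvm.umul.with.overflow.i32 without needing the new combine.
static bool mul_overflows_builtin(unsigned x, unsigned y) {
  unsigned prod;
  return __builtin_umul_overflow(x, y, &prod);
}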