[llvm] a77dedc - [InstSimplify][InstCombine][ConstantFold] Move vector div/rem by zero fold to InstCombine (#114280)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 07:56:27 PDT 2024
Author: Yingwei Zheng
Date: 2024-11-01T22:56:22+08:00
New Revision: a77dedcacb4c5eb221395b69877981dd6ad98989
URL: https://github.com/llvm/llvm-project/commit/a77dedcacb4c5eb221395b69877981dd6ad98989
DIFF: https://github.com/llvm/llvm-project/commit/a77dedcacb4c5eb221395b69877981dd6ad98989.diff
LOG: [InstSimplify][InstCombine][ConstantFold] Move vector div/rem by zero fold to InstCombine (#114280)
Previously we folded `div/rem X, C` into `poison` if any element of the
constant divisor `C` was zero or undef. However, this is incorrect when
threading udiv over a vector select:
https://alive2.llvm.org/ce/z/3Ninx5
```
define <2 x i32> @vec_select_udiv_poison(<2 x i1> %x) {
%sel = select <2 x i1> %x, <2 x i32> <i32 -1, i32 -1>, <2 x i32> <i32 0, i32 1>
%div = udiv <2 x i32> <i32 42, i32 -7>, %sel
ret <2 x i32> %div
}
```
In this case, `threadBinOpOverSelect` folds `udiv <i32 42, i32 -7>, <i32
-1, i32 -1>` and `udiv <i32 42, i32 -7>, <i32 0, i32 1>` into
`zeroinitializer` and `poison`, respectively. The second fold is wrong
here: when `%x` is `<true, false>` the divisor is `<i32 -1, i32 1>`, so
lane 1 of the result must be `-7` rather than `poison`. One solution is to
introduce a new flag indicating that we are threading over a vector
select, but that requires modifying both `InstSimplify` and
`ConstantFold`.
However, this optimization doesn't provide benefits to real-world
programs:
https://dtcxzyw.github.io/llvm-opt-benchmark/coverage/data/zyw/opt-ci/actions-runner/_work/llvm-opt-benchmark/llvm-opt-benchmark/llvm/llvm-project/llvm/lib/IR/ConstantFold.cpp.html#L908
https://dtcxzyw.github.io/llvm-opt-benchmark/coverage/data/zyw/opt-ci/actions-runner/_work/llvm-opt-benchmark/llvm-opt-benchmark/llvm/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp.html#L1107
This patch moves the fold into InstCombine to avoid breaking numerous
existing tests.
Fixes #114191 and #113866 (only the poison-safety issue).
Added:
Modified:
llvm/lib/Analysis/InstructionSimplify.cpp
llvm/lib/IR/ConstantFold.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/test/Transforms/InstCombine/div.ll
llvm/test/Transforms/InstCombine/rem.ll
llvm/test/Transforms/InstCombine/vector-udiv.ll
llvm/test/Transforms/InstSimplify/div.ll
llvm/test/Transforms/InstSimplify/rem.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d08be1e55c853e..2cb2612bf611e3 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1095,19 +1095,6 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
if (match(Op1, m_Zero()))
return PoisonValue::get(Ty);
- // If any element of a constant divisor fixed width vector is zero or undef
- // the behavior is undefined and we can fold the whole op to poison.
- auto *Op1C = dyn_cast<Constant>(Op1);
- auto *VTy = dyn_cast<FixedVectorType>(Ty);
- if (Op1C && VTy) {
- unsigned NumElts = VTy->getNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = Op1C->getAggregateElement(i);
- if (Elt && (Elt->isNullValue() || Q.isUndefValue(Elt)))
- return PoisonValue::get(Ty);
- }
- }
-
// poison / X -> poison
// poison % X -> poison
if (isa<PoisonValue>(Op0))
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 07dfbc41e79b00..c2780faee403d4 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -902,11 +902,6 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
Constant *ExtractIdx = ConstantInt::get(Ty, i);
Constant *LHS = ConstantExpr::getExtractElement(C1, ExtractIdx);
Constant *RHS = ConstantExpr::getExtractElement(C2, ExtractIdx);
-
- // If any element of a divisor vector is zero, the whole op is poison.
- if (Instruction::isIntDivRem(Opcode) && RHS->isNullValue())
- return PoisonValue::get(VTy);
-
Constant *Res = ConstantExpr::isDesirableBinOp(Opcode)
? ConstantExpr::get(Opcode, LHS, RHS)
: ConstantFoldBinaryInstruction(Opcode, LHS, RHS);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 7a060cdab2d37d..adbd9186c59c5a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -102,6 +102,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *visitSRem(BinaryOperator &I);
Instruction *visitFRem(BinaryOperator &I);
bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I);
+ Instruction *commonIDivRemTransforms(BinaryOperator &I);
Instruction *commonIRemTransforms(BinaryOperator &I);
Instruction *commonIDivTransforms(BinaryOperator &I);
Instruction *visitUDiv(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index b9c165da906da4..f85a3c93651353 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1158,29 +1158,39 @@ static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) {
return nullptr;
}
-/// This function implements the transforms common to both integer division
-/// instructions (udiv and sdiv). It is called by the visitors to those integer
-/// division instructions.
-/// Common integer divide transforms
-Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
- if (Instruction *Phi = foldBinopWithPhiOperands(I))
- return Phi;
-
+/// Common integer divide/remainder transforms
+Instruction *InstCombinerImpl::commonIDivRemTransforms(BinaryOperator &I) {
+ assert(I.isIntDivRem() && "Unexpected instruction");
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- bool IsSigned = I.getOpcode() == Instruction::SDiv;
+
+ // If any element of a constant divisor fixed width vector is zero or undef
+ // the behavior is undefined and we can fold the whole op to poison.
+ auto *Op1C = dyn_cast<Constant>(Op1);
Type *Ty = I.getType();
+ auto *VTy = dyn_cast<FixedVectorType>(Ty);
+ if (Op1C && VTy) {
+ unsigned NumElts = VTy->getNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *Elt = Op1C->getAggregateElement(i);
+ if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt)))
+ return replaceInstUsesWith(I, PoisonValue::get(Ty));
+ }
+ }
+
+ if (Instruction *Phi = foldBinopWithPhiOperands(I))
+ return Phi;
// The RHS is known non-zero.
if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
return replaceOperand(I, 1, V);
- // Handle cases involving: [su]div X, (select Cond, Y, Z)
- // This does not apply for fdiv.
+ // Handle cases involving: div/rem X, (select Cond, Y, Z)
if (simplifyDivRemOfSelectWithZeroOp(I))
return &I;
// If the divisor is a select-of-constants, try to constant fold all div ops:
- // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC)
+ // C div/rem (select Cond, TrueC, FalseC) --> select Cond, (C div/rem TrueC),
+ // (C div/rem FalseC)
// TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
if (match(Op0, m_ImmConstant()) &&
match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
@@ -1189,6 +1199,21 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
return R;
}
+ return nullptr;
+}
+
+/// This function implements the transforms common to both integer division
+/// instructions (udiv and sdiv). It is called by the visitors to those integer
+/// division instructions.
+/// Common integer divide transforms
+Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
+ if (Instruction *Res = commonIDivRemTransforms(I))
+ return Res;
+
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ bool IsSigned = I.getOpcode() == Instruction::SDiv;
+ Type *Ty = I.getType();
+
const APInt *C2;
if (match(Op1, m_APInt(C2))) {
Value *X;
@@ -2138,29 +2163,11 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
/// remainder instructions.
/// Common integer remainder transforms
Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
- if (Instruction *Phi = foldBinopWithPhiOperands(I))
- return Phi;
+ if (Instruction *Res = commonIDivRemTransforms(I))
+ return Res;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- // The RHS is known non-zero.
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
- return replaceOperand(I, 1, V);
-
- // Handle cases involving: rem X, (select Cond, Y, Z)
- if (simplifyDivRemOfSelectWithZeroOp(I))
- return &I;
-
- // If the divisor is a select-of-constants, try to constant fold all rem ops:
- // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC)
- // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
- if (match(Op0, m_ImmConstant()) &&
- match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
- if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
- /*FoldWithMultiUse*/ true))
- return R;
- }
-
if (isa<Constant>(Op1)) {
if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index a91c9bfc91c40d..d5d7ce9b7b2636 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -1163,7 +1163,8 @@ define <2 x i8> @sdiv_constant_dividend_select_of_constants_divisor_vec(i1 %b) {
define <2 x i8> @sdiv_constant_dividend_select_of_constants_divisor_vec_ub1(i1 %b) {
; CHECK-LABEL: @sdiv_constant_dividend_select_of_constants_divisor_vec_ub1(
-; CHECK-NEXT: ret <2 x i8> <i8 -10, i8 -10>
+; CHECK-NEXT: [[R:%.*]] = select i1 [[B:%.*]], <2 x i8> <i8 poison, i8 8>, <2 x i8> <i8 -10, i8 -10>
+; CHECK-NEXT: ret <2 x i8> [[R]]
;
%s = select i1 %b, <2 x i8> <i8 0, i8 -5>, <2 x i8> <i8 -4, i8 4>
%r = sdiv <2 x i8> <i8 42, i8 -42>, %s
@@ -1269,7 +1270,8 @@ define <2 x i8> @udiv_constant_dividend_select_of_constants_divisor_vec(i1 %b) {
define <2 x i8> @udiv_constant_dividend_select_of_constants_divisor_vec_ub1(i1 %b) {
; CHECK-LABEL: @udiv_constant_dividend_select_of_constants_divisor_vec_ub1(
-; CHECK-NEXT: ret <2 x i8> <i8 0, i8 53>
+; CHECK-NEXT: [[R:%.*]] = select i1 [[B:%.*]], <2 x i8> <i8 poison, i8 0>, <2 x i8> <i8 0, i8 53>
+; CHECK-NEXT: ret <2 x i8> [[R]]
;
%s = select i1 %b, <2 x i8> <i8 0, i8 -5>, <2 x i8> <i8 -4, i8 4>
%r = udiv <2 x i8> <i8 42, i8 -42>, %s
diff --git a/llvm/test/Transforms/InstCombine/rem.ll b/llvm/test/Transforms/InstCombine/rem.ll
index 4262ef85553b64..4f7687aeaf8bc8 100644
--- a/llvm/test/Transforms/InstCombine/rem.ll
+++ b/llvm/test/Transforms/InstCombine/rem.ll
@@ -997,7 +997,8 @@ define <2 x i8> @urem_constant_dividend_select_of_constants_divisor_vec(i1 %b) {
define <2 x i8> @urem_constant_dividend_select_of_constants_divisor_vec_ub1(i1 %b) {
; CHECK-LABEL: @urem_constant_dividend_select_of_constants_divisor_vec_ub1(
-; CHECK-NEXT: ret <2 x i8> <i8 42, i8 2>
+; CHECK-NEXT: [[R:%.*]] = select i1 [[B:%.*]], <2 x i8> <i8 poison, i8 -42>, <2 x i8> <i8 42, i8 2>
+; CHECK-NEXT: ret <2 x i8> [[R]]
;
%s = select i1 %b, <2 x i8> <i8 0, i8 -5>, <2 x i8> <i8 -4, i8 4>
%r = urem <2 x i8> <i8 42, i8 -42>, %s
diff --git a/llvm/test/Transforms/InstCombine/vector-udiv.ll b/llvm/test/Transforms/InstCombine/vector-udiv.ll
index c817b3a1ac5a0a..0289b7c70cc4fb 100644
--- a/llvm/test/Transforms/InstCombine/vector-udiv.ll
+++ b/llvm/test/Transforms/InstCombine/vector-udiv.ll
@@ -97,3 +97,16 @@ define <4 x i32> @test_v4i32_zext_shl_const_pow2(<4 x i32> %a0, <4 x i16> %a1) {
%3 = udiv <4 x i32> %a0, %2
ret <4 x i32> %3
}
+
+; Make sure we do not simplify udiv <i32 42, i32 -7>, <i32 0, i32 1> to
+; poison when threading udiv over selects
+
+define <2 x i32> @vec_select_udiv_poison(<2 x i1> %x) {
+; CHECK-LABEL: @vec_select_udiv_poison(
+; CHECK-NEXT: [[DIV:%.*]] = select <2 x i1> [[X:%.*]], <2 x i32> zeroinitializer, <2 x i32> <i32 poison, i32 -7>
+; CHECK-NEXT: ret <2 x i32> [[DIV]]
+;
+ %sel = select <2 x i1> %x, <2 x i32> <i32 -1, i32 -1>, <2 x i32> <i32 0, i32 1>
+ %div = udiv <2 x i32> <i32 42, i32 -7>, %sel
+ ret <2 x i32> %div
+}
diff --git a/llvm/test/Transforms/InstSimplify/div.ll b/llvm/test/Transforms/InstSimplify/div.ll
index 5ca2e8837b924b..e2bc121aee4571 100644
--- a/llvm/test/Transforms/InstSimplify/div.ll
+++ b/llvm/test/Transforms/InstSimplify/div.ll
@@ -29,7 +29,7 @@ define <2 x i32> @zero_dividend_vector_poison_elt(<2 x i32> %A) {
define <2 x i8> @sdiv_zero_elt_vec_constfold(<2 x i8> %x) {
; CHECK-LABEL: @sdiv_zero_elt_vec_constfold(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: ret <2 x i8> <i8 poison, i8 0>
;
%div = sdiv <2 x i8> <i8 1, i8 2>, <i8 0, i8 -42>
ret <2 x i8> %div
@@ -37,7 +37,7 @@ define <2 x i8> @sdiv_zero_elt_vec_constfold(<2 x i8> %x) {
define <2 x i8> @udiv_zero_elt_vec_constfold(<2 x i8> %x) {
; CHECK-LABEL: @udiv_zero_elt_vec_constfold(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: ret <2 x i8> <i8 0, i8 poison>
;
%div = udiv <2 x i8> <i8 1, i8 2>, <i8 42, i8 0>
ret <2 x i8> %div
@@ -45,7 +45,8 @@ define <2 x i8> @udiv_zero_elt_vec_constfold(<2 x i8> %x) {
define <2 x i8> @sdiv_zero_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @sdiv_zero_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[DIV:%.*]] = sdiv <2 x i8> [[X:%.*]], <i8 -42, i8 0>
+; CHECK-NEXT: ret <2 x i8> [[DIV]]
;
%div = sdiv <2 x i8> %x, <i8 -42, i8 0>
ret <2 x i8> %div
@@ -53,7 +54,8 @@ define <2 x i8> @sdiv_zero_elt_vec(<2 x i8> %x) {
define <2 x i8> @udiv_zero_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @udiv_zero_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[DIV:%.*]] = udiv <2 x i8> [[X:%.*]], <i8 0, i8 42>
+; CHECK-NEXT: ret <2 x i8> [[DIV]]
;
%div = udiv <2 x i8> %x, <i8 0, i8 42>
ret <2 x i8> %div
@@ -61,7 +63,8 @@ define <2 x i8> @udiv_zero_elt_vec(<2 x i8> %x) {
define <2 x i8> @sdiv_poison_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @sdiv_poison_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[DIV:%.*]] = sdiv <2 x i8> [[X:%.*]], <i8 -42, i8 poison>
+; CHECK-NEXT: ret <2 x i8> [[DIV]]
;
%div = sdiv <2 x i8> %x, <i8 -42, i8 poison>
ret <2 x i8> %div
@@ -69,7 +72,8 @@ define <2 x i8> @sdiv_poison_elt_vec(<2 x i8> %x) {
define <2 x i8> @udiv_poison_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @udiv_poison_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[DIV:%.*]] = udiv <2 x i8> [[X:%.*]], <i8 poison, i8 42>
+; CHECK-NEXT: ret <2 x i8> [[DIV]]
;
%div = udiv <2 x i8> %x, <i8 poison, i8 42>
ret <2 x i8> %div
diff --git a/llvm/test/Transforms/InstSimplify/rem.ll b/llvm/test/Transforms/InstSimplify/rem.ll
index aceb7cb12185d6..5ec803c6d0481e 100644
--- a/llvm/test/Transforms/InstSimplify/rem.ll
+++ b/llvm/test/Transforms/InstSimplify/rem.ll
@@ -29,7 +29,7 @@ define <2 x i32> @zero_dividend_vector_poison_elt(<2 x i32> %A) {
define <2 x i8> @srem_zero_elt_vec_constfold(<2 x i8> %x) {
; CHECK-LABEL: @srem_zero_elt_vec_constfold(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: ret <2 x i8> <i8 poison, i8 2>
;
%rem = srem <2 x i8> <i8 1, i8 2>, <i8 0, i8 -42>
ret <2 x i8> %rem
@@ -37,7 +37,7 @@ define <2 x i8> @srem_zero_elt_vec_constfold(<2 x i8> %x) {
define <2 x i8> @urem_zero_elt_vec_constfold(<2 x i8> %x) {
; CHECK-LABEL: @urem_zero_elt_vec_constfold(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: ret <2 x i8> <i8 1, i8 poison>
;
%rem = urem <2 x i8> <i8 1, i8 2>, <i8 42, i8 0>
ret <2 x i8> %rem
@@ -45,7 +45,8 @@ define <2 x i8> @urem_zero_elt_vec_constfold(<2 x i8> %x) {
define <2 x i8> @srem_zero_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @srem_zero_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[REM:%.*]] = srem <2 x i8> [[X:%.*]], <i8 -42, i8 0>
+; CHECK-NEXT: ret <2 x i8> [[REM]]
;
%rem = srem <2 x i8> %x, <i8 -42, i8 0>
ret <2 x i8> %rem
@@ -53,7 +54,8 @@ define <2 x i8> @srem_zero_elt_vec(<2 x i8> %x) {
define <2 x i8> @urem_zero_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @urem_zero_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[REM:%.*]] = urem <2 x i8> [[X:%.*]], <i8 0, i8 42>
+; CHECK-NEXT: ret <2 x i8> [[REM]]
;
%rem = urem <2 x i8> %x, <i8 0, i8 42>
ret <2 x i8> %rem
@@ -61,7 +63,8 @@ define <2 x i8> @urem_zero_elt_vec(<2 x i8> %x) {
define <2 x i8> @srem_undef_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @srem_undef_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[REM:%.*]] = srem <2 x i8> [[X:%.*]], <i8 -42, i8 undef>
+; CHECK-NEXT: ret <2 x i8> [[REM]]
;
%rem = srem <2 x i8> %x, <i8 -42, i8 undef>
ret <2 x i8> %rem
@@ -69,7 +72,8 @@ define <2 x i8> @srem_undef_elt_vec(<2 x i8> %x) {
define <2 x i8> @urem_undef_elt_vec(<2 x i8> %x) {
; CHECK-LABEL: @urem_undef_elt_vec(
-; CHECK-NEXT: ret <2 x i8> poison
+; CHECK-NEXT: [[REM:%.*]] = urem <2 x i8> [[X:%.*]], <i8 undef, i8 42>
+; CHECK-NEXT: ret <2 x i8> [[REM]]
;
%rem = urem <2 x i8> %x, <i8 undef, i8 42>
ret <2 x i8> %rem
More information about the llvm-commits
mailing list