[llvm] [InstCombine] Factorise add/sub and max/min using distributivity (PR #101507)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 1 09:11:38 PDT 2024
================
@@ -1505,6 +1505,80 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
ConstantInt::getTrue(ZeroUndef->getType()));
}
+/// Report whether the intrinsic LOp distributes over the intrinsic ROp, i.e.
+/// whether "X LOp (Y ROp Z)" and "(X LOp Y) ROp (X LOp Z)" are always equal.
+///
+/// The accepted pairs are: a min/max intrinsic over the opposite max/min of
+/// the same signedness, and a saturating add over either min or max of the
+/// matching signedness. Every other combination yields false.
+static bool leftDistributesOverRightIntrinsic(Intrinsic::ID LOp,
+                                              Intrinsic::ID ROp) {
+  // Dispatch on ROp and, for each one, list the LOp values that distribute
+  // over it.
+  switch (ROp) {
+  case Intrinsic::umin:
+    return LOp == Intrinsic::umax || LOp == Intrinsic::uadd_sat;
+  case Intrinsic::umax:
+    return LOp == Intrinsic::umin || LOp == Intrinsic::uadd_sat;
+  case Intrinsic::smin:
+    return LOp == Intrinsic::smax || LOp == Intrinsic::sadd_sat;
+  case Intrinsic::smax:
+    return LOp == Intrinsic::smin || LOp == Intrinsic::sadd_sat;
+  default:
+    // ROp is not one of the four integer min/max intrinsics.
+    return false;
+  }
+}
+
+// Attempts to factorise a common term
+// in an instruction that has the form "(A op' B) op (C op' D)".
+static Instruction *
+foldCallUsingDistributiveLaws(CallInst *II, InstCombiner::BuilderTy &Builder) {
+ Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
+ Intrinsic::ID TopLevelOpcode = II->getCalledFunction()->getIntrinsicID();
+
+ if (LHS && RHS) {
+ CallInst *Op0 = dyn_cast<CallInst>(LHS);
+ CallInst *Op1 = dyn_cast<CallInst>(RHS);
+
+ if (!Op0 || !Op1)
+ return nullptr;
+
+ if (Op0->getCalledFunction()->getIntrinsicID() !=
+ Op1->getCalledFunction()->getIntrinsicID())
+ return nullptr;
+
+ Intrinsic::ID InnerOpcode = Op0->getCalledFunction()->getIntrinsicID();
+
+ bool InnerCommutative = Op0->isCommutative();
+ bool Distributive =
+ leftDistributesOverRightIntrinsic(InnerOpcode, TopLevelOpcode);
+
+ Value *A = Op0->getOperand(0);
+ Value *B = Op0->getOperand(1);
+ Value *C = Op1->getOperand(0);
+ Value *D = Op1->getOperand(1);
+
+ if (Distributive && (A == C || (InnerCommutative && A == D))) {
+ if (A != C)
+ std::swap(C, D);
+
+ Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
+ Function *F = Intrinsic::getDeclaration(II->getModule(), InnerOpcode,
+ II->getType());
+ return CallInst::Create(F, {NewIntrinsic, A});
----------------
goldsteinn wrote:
This can be:
```
return replaceInstUsesWith(*II, Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D));
```
Similarly below.
https://github.com/llvm/llvm-project/pull/101507
More information about the llvm-commits
mailing list