[llvm] [InstCombine] Factorise Add and Min/Max using Distributivity (PR #101717)

Jorge Botto via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 2 14:59:57 PDT 2024


================
@@ -1505,6 +1505,97 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
       ConstantInt::getTrue(ZeroUndef->getType()));
 }
 
+/// Decide whether the binary operator \p LOp distributes over the min/max
+/// intrinsic \p ROp, i.e. whether "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)".
+///
+/// \param LOp    The inner binary opcode; only Instruction::Add is accepted.
+/// \param hasNUW Whether the inner operation carries the no-unsigned-wrap flag.
+/// \param hasNSW Whether the inner operation carries the no-signed-wrap flag.
+/// \param ROp    The outer intrinsic (umax, umin, smax or smin).
+/// \returns true only for an add paired with an unsigned min/max when
+/// \p hasNUW is set, or with a signed min/max when \p hasNSW is set; false
+/// for every other combination.
+static bool leftDistributesOverRightIntrinsic(Instruction::BinaryOps LOp,
+                                              bool hasNUW, bool hasNSW,
+                                              Intrinsic::ID ROp) {
+  // Every supported case requires the inner operation to be an add, so
+  // reject anything else up front.
+  if (LOp != Instruction::Add)
+    return false;
+
+  switch (ROp) {
+  // Unsigned min/max: the add must be flagged as not wrapping unsigned.
+  case Intrinsic::umax:
+  case Intrinsic::umin:
+    return hasNUW;
+  // Signed min/max: the add must be flagged as not wrapping signed.
+  case Intrinsic::smax:
+  case Intrinsic::smin:
+    return hasNSW;
+  default:
+    return false;
+  }
+}
+
+// Attempts to factorise a common term
+// in an instruction that has the form "(A op' B) op (C op' D)",
+// where op is an intrinsic and op' is a binop.
+static Value *
+foldIntrinsicUsingDistributiveLaws(IntrinsicInst *II,
+                                   InstCombiner::BuilderTy &Builder) {
+  Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
+  Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();
+
+  OverflowingBinaryOperator *Op0 = dyn_cast<OverflowingBinaryOperator>(LHS);
+  OverflowingBinaryOperator *Op1 = dyn_cast<OverflowingBinaryOperator>(RHS);
+
+  if (!Op0 || !Op1)
+    return nullptr;
+
+  if (Op0->getOpcode() != Op1->getOpcode())
+    return nullptr;
+
+  if (Op0->hasNoUnsignedWrap() != Op1->hasNoUnsignedWrap() ||
+      Op0->hasNoSignedWrap() != Op1->hasNoSignedWrap())
+    return nullptr;
+
+  if (!Op0->hasOneUse() || !Op1->hasOneUse())
+    return nullptr;
+
+  Instruction::BinaryOps InnerOpcode =
+      static_cast<Instruction::BinaryOps>(Op0->getOpcode());
+  bool HasNUW = Op0->hasNoUnsignedWrap();
+  bool HasNSW = Op0->hasNoSignedWrap();
+
+  if (!InnerOpcode)
+    return nullptr;
+
+  if (!leftDistributesOverRightIntrinsic(InnerOpcode, HasNUW, HasNSW,
+                                         TopLevelOpcode))
+    return nullptr;
+
+  assert(II->isCommutative() && Op0->isCommutative() &&
+         "Only inner and outer commutative op codes are supported.");
+
+  Value *A = Op0->getOperand(0);
+  Value *B = Op0->getOperand(1);
+  Value *C = Op1->getOperand(0);
+  Value *D = Op1->getOperand(1);
+
+  if (A == C || A == D) {
+    if (A != C)
+      std::swap(C, D);
+
+    Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
+    BinaryOperator *NewBinop =
+        cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, A));
+    NewBinop->setHasNoSignedWrap(HasNSW);
+    NewBinop->setHasNoUnsignedWrap(HasNUW);
+    return NewBinop;
+  }
+
+  if (B == D || B == C) {
+    if (B != D)
+      std::swap(C, D);
+
+    Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, A, C);
+    BinaryOperator *NewBinop =
+        cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, B));
+    NewBinop->setHasNoSignedWrap(HasNSW);
+    NewBinop->setHasNoUnsignedWrap(HasNUW);
+    return NewBinop;
----------------
jf-botto wrote:

Thanks, totally get it. I've simplified the boolean logic into fewer/simpler arguments.

https://github.com/llvm/llvm-project/pull/101717


More information about the llvm-commits mailing list