[llvm] [InstCombine] Factorise add/sub and max/min using distributivity (PR #101507)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 10:20:04 PDT 2024


================
@@ -1505,6 +1505,76 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
       ConstantInt::getTrue(ZeroUndef->getType()));
 }
 
+/// Return true if the intrinsic \p LOp distributes from the left over the
+/// intrinsic \p ROp, i.e. "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)" for every X, Y and Z.
+static bool leftDistributesOverRightIntrinsic(Intrinsic::ID LOp,
+                                              Intrinsic::ID ROp) {
+  const bool RIsUnsignedMinMax =
+      ROp == Intrinsic::umin || ROp == Intrinsic::umax;
+  const bool RIsSignedMinMax =
+      ROp == Intrinsic::smin || ROp == Intrinsic::smax;
+
+  // A saturating add is monotonic in each operand, so adding X commutes
+  // with taking the min/max (of matching signedness) of Y and Z.
+  if (LOp == Intrinsic::uadd_sat)
+    return RIsUnsignedMinMax;
+  if (LOp == Intrinsic::sadd_sat)
+    return RIsSignedMinMax;
+
+  // min and max of the same signedness distribute over each other; only
+  // this dual pairing is reported here.
+  if (LOp == Intrinsic::umin)
+    return ROp == Intrinsic::umax;
+  if (LOp == Intrinsic::umax)
+    return ROp == Intrinsic::umin;
+  if (LOp == Intrinsic::smin)
+    return ROp == Intrinsic::smax;
+  if (LOp == Intrinsic::smax)
+    return ROp == Intrinsic::smin;
+
+  // Any other intrinsic is conservatively treated as non-distributive.
+  return false;
+}
+
+// Attempts to factorise a common term out of an instruction
+// that has the form "(A op' B) op (C op' D)".
+static Value *
+foldIntrinsicUsingDistributiveLaws(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+  Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
+  Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();
+
+  if (LHS && RHS) {
+    IntrinsicInst *Op0 = dyn_cast<IntrinsicInst>(LHS);
+    IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(RHS);
+
+    if (!Op0 || !Op1)
+      return nullptr;
+
+    if (Op0->getIntrinsicID() !=
+        Op1->getIntrinsicID())
+      return nullptr;
+
+    Intrinsic::ID InnerOpcode = Op0->getIntrinsicID();
+
+    bool InnerCommutative = Op0->isCommutative();
+    bool Distributive =
+        leftDistributesOverRightIntrinsic(InnerOpcode, TopLevelOpcode);
+
+    Value *A = Op0->getOperand(0);
+    Value *B = Op0->getOperand(1);
+    Value *C = Op1->getOperand(0);
+    Value *D = Op1->getOperand(1);
+
+    if (Distributive && (A == C || (InnerCommutative && A == D))) {
+      if (A != C)
+        std::swap(C, D);
+
+      Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
----------------
goldsteinn wrote:

You are relying on `TopLevelOpcode` being commutative here which is fine given that all of the intrins supported are commutative, but a bit odd given that you go through the motions of handling whether inner is commutative (which it can't be given the intrins currently supported).

I would just add an `assert` that both `II` and `Op0` are commutative and add extra logic to handle the non-commutative case for now (NB: `assert` after the early return if `Distributive` is false).

https://github.com/llvm/llvm-project/pull/101507


More information about the llvm-commits mailing list