[llvm] f583117 - [InstCombine] refactor the SimplifyUsingDistributiveLaws NFC

via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 30 06:05:54 PDT 2022


Author: zhongyunde
Date: 2022-10-30T21:04:06+08:00
New Revision: f58311796c49223de1fb3a09248698afe73e2852

URL: https://github.com/llvm/llvm-project/commit/f58311796c49223de1fb3a09248698afe73e2852
DIFF: https://github.com/llvm/llvm-project/commit/f58311796c49223de1fb3a09248698afe73e2852.diff

LOG: [InstCombine] refactor the SimplifyUsingDistributiveLaws NFC

Precommit for D136015
Reviewed By: spatel
Differential Revision: https://reviews.llvm.org/D137019

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 9a4225ed5d93..76234ccfdcff 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1231,7 +1231,7 @@ Instruction *InstCombinerImpl::
 }
 
 /// This is a specialization of a more general transform from
-/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// foldUsingDistributiveLaws. If that code can be made to work optimally
 /// for multi-use cases or propagating nsw/nuw, then we would not need this.
 static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
                                             InstCombiner::BuilderTy &Builder) {
@@ -1322,7 +1322,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return Phi;
 
   // (A*B)+(A*C) -> A*(B+C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *R = foldBoxMultiply(I))
@@ -1944,7 +1944,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     return TryToNarrowDeduceFlags(); // Should have been handled in Negator!
 
   // (A*B)-(A*C) -> A*(B-C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (I.getType()->isIntOrIntVectorTy(1))

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ac6780b51864..c3dabc2d4a07 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1847,7 +1847,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
     return X;
 
   // (A|B)&(A|C) -> A|(B&C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyBSwap(I, Builder))
@@ -2825,7 +2825,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     return X;
 
   // (A&B)|(A&C) -> A&(B|C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyBSwap(I, Builder))
@@ -3768,7 +3768,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
     return NewXor;
 
   // (A&B)^(A&C) -> A&(B^C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   // See if we can simplify any instructions used by the instruction whose sole

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 1a064abdcfcb..af5ae1d63a41 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -544,7 +544,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A
   /// & (B | C) -> (A&B) | (A&C)" if this is a win).  Returns the simplified
   /// value, or null if it didn't simplify.
-  Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
+  Value *foldUsingDistributiveLaws(BinaryOperator &I);
 
   /// Tries to simplify add operations using the definition of remainder.
   ///
@@ -560,8 +560,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-  Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
-                          Value *, Value *, Value *);
+  Value *tryFactorizationFolds(BinaryOperator &I);
 
   /// Match a select chain which produces one of three values based on whether
   /// the LHS is less than, equal to, or greater than RHS respectively.

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index c3ad993e5da0..e4fccda750e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -188,7 +188,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   Type *Ty = I.getType();

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 091bb96b6365..ebdd325335fd 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -633,10 +633,10 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op,
 
 /// This tries to simplify binary operations by factorizing out common terms
 /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
-                                          Instruction::BinaryOps InnerOpcode,
-                                          Value *A, Value *B, Value *C,
-                                          Value *D) {
+static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
+                               InstCombiner::BuilderTy &Builder,
+                               Instruction::BinaryOps InnerOpcode, Value *A,
+                               Value *B, Value *C, Value *D) {
   assert(A && B && C && D && "All values must be provided");
 
   Value *V = nullptr;
@@ -730,46 +730,58 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
   return RetVal;
 }
 
+Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
+  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+  Value *A, *B, *C, *D;
+  Instruction::BinaryOps LHSOpcode, RHSOpcode;
+
+  if (Op0)
+    LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+  if (Op1)
+    RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+  // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize
+  // a common term.
+  if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+    if (Value *V = tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, D))
+      return V;
+
+  // The instruction has the form "(A op' B) op (C)".  Try to factorize common
+  // term.
+  if (Op0)
+    if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, LHSOpcode, A, B, RHS, Ident))
+        return V;
+
+  // The instruction has the form "(B) op (C op' D)".  Try to factorize common
+  // term.
+  if (Op1)
+    if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D))
+        return V;
+
+  return nullptr;
+}
+
 /// This tries to simplify binary operations which some other binary operation
 /// distributes over either by factorizing out common terms
 /// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in
 /// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win).
 /// Returns the simplified value, or null if it didn't simplify.
-Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
+Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
   Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
 
-  {
-    // Factorization.
-    Value *A, *B, *C, *D;
-    Instruction::BinaryOps LHSOpcode, RHSOpcode;
-    if (Op0)
-      LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
-    if (Op1)
-      RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
-
-    // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize
-    // a common term.
-    if (Op0 && Op1 && LHSOpcode == RHSOpcode)
-      if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D))
-        return V;
-
-    // The instruction has the form "(A op' B) op (C)".  Try to factorize common
-    // term.
-    if (Op0)
-      if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
-        if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident))
-          return V;
-
-    // The instruction has the form "(B) op (C op' D)".  Try to factorize common
-    // term.
-    if (Op1)
-      if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
-        if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D))
-          return V;
-  }
+  // Factorization.
+  if (Value *R = tryFactorizationFolds(I))
+    return R;
 
   // Expansion.
   if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {


        


More information about the llvm-commits mailing list