[llvm] f34dcf2 - [IRBuilder] Migrate all binops to folding API

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 30 07:41:25 PDT 2022


Author: Nikita Popov
Date: 2022-06-30T16:41:17+02:00
New Revision: f34dcf27637f5b657d9e244187631243bfccc25a

URL: https://github.com/llvm/llvm-project/commit/f34dcf27637f5b657d9e244187631243bfccc25a
DIFF: https://github.com/llvm/llvm-project/commit/f34dcf27637f5b657d9e244187631243bfccc25a.diff

LOG: [IRBuilder] Migrate all binops to folding API

Migrate all binops to use the FoldXYZ APIs rather than the CreateXYZ
APIs. Unlike CreateXYZ, the FoldXYZ APIs are compatible with
InstSimplifyFolder and fallible constant folding.

Rather than continuing to add one method for every single operator,
add a generic FoldBinOp (plus variants for nowrap, exact and fmf
operators), which we would need anyway for CreateBinOp.

This change is not NFC because IRBuilder with InstSimplifyFolder
may perform more folding. However, this patch changes SCEVExpander
to not use the folder in InsertBinOp to minimize practical impact
and keep this change as close to NFC as possible.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstSimplifyFolder.h
    llvm/include/llvm/Analysis/TargetFolder.h
    llvm/include/llvm/IR/ConstantFolder.h
    llvm/include/llvm/IR/IRBuilder.h
    llvm/include/llvm/IR/IRBuilderFolder.h
    llvm/include/llvm/IR/NoFolder.h
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
index a67424f605363..33b6f12efbca6 100644
--- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h
+++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
@@ -46,33 +46,25 @@ class InstSimplifyFolder final : public IRBuilderFolder {
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    return simplifyAddInst(LHS, RHS, HasNUW, HasNSW, SQ);
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    return simplifyAndInst(LHS, RHS, SQ);
-  }
 
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
-    return simplifyOrInst(LHS, RHS, SQ);
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
-    return simplifyUDivInst(LHS, RHS, SQ);
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
-    return simplifySDivInst(LHS, RHS, SQ);
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
-    return simplifyURemInst(LHS, RHS, SQ);
-  }
-
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    return simplifySRemInst(LHS, RHS, SQ);
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return simplifyBinOp(Opc, LHS, RHS, FMF, SQ);
   }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
@@ -115,54 +107,6 @@ class InstSimplifyFolder final : public IRBuilderFolder {
     return simplifyShuffleVectorInst(V1, V2, Mask, RetTy, SQ);
   }
 
-  //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Value *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFAdd(LHS, RHS);
-  }
-  Value *CreateSub(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateSub(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFSub(LHS, RHS);
-  }
-  Value *CreateMul(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateMul(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFMul(LHS, RHS);
-  }
-  Value *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFDiv(LHS, RHS);
-  }
-  Value *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFRem(LHS, RHS);
-  }
-  Value *CreateShl(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateShl(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateLShr(Constant *LHS, Constant *RHS,
-                    bool isExact = false) const override {
-    return ConstFolder.CreateLShr(LHS, RHS, isExact);
-  }
-  Value *CreateAShr(Constant *LHS, Constant *RHS,
-                    bool isExact = false) const override {
-    return ConstFolder.CreateAShr(LHS, RHS, isExact);
-  }
-  Value *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateXor(LHS, RHS);
-  }
-
-  Value *CreateBinOp(Instruction::BinaryOps Opc, Constant *LHS,
-                     Constant *RHS) const override {
-    return ConstFolder.CreateBinOp(Opc, LHS, RHS);
-  }
-
   //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/Analysis/TargetFolder.h b/llvm/include/llvm/Analysis/TargetFolder.h
index 1187e9cfd02a2..a360be5313aec 100644
--- a/llvm/include/llvm/Analysis/TargetFolder.h
+++ b/llvm/include/llvm/Analysis/TargetFolder.h
@@ -22,6 +22,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilderFolder.h"
+#include "llvm/IR/Operator.h"
 
 namespace llvm {
 
@@ -49,63 +50,45 @@ class TargetFolder final : public IRBuilderFolder {
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getAdd(LC, RC, HasNUW, HasNSW));
-    return nullptr;
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getAnd(LC, RC));
-    return nullptr;
-  }
 
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return Fold(ConstantExpr::getOr(LC, RC));
+      return Fold(ConstantExpr::get(Opc, LC, RC));
     return nullptr;
   }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return Fold(ConstantExpr::getUDiv(LC, RC, IsExact));
+      return Fold(ConstantExpr::get(
+          Opc, LC, RC, IsExact ? PossiblyExactOperator::IsExact : 0));
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getSDiv(LC, RC, IsExact));
-    return nullptr;
-  }
-
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getURem(LC, RC));
+    if (LC && RC) {
+      unsigned Flags = 0;
+      if (HasNUW)
+        Flags |= OverflowingBinaryOperator::NoUnsignedWrap;
+      if (HasNSW)
+        Flags |= OverflowingBinaryOperator::NoSignedWrap;
+      return Fold(ConstantExpr::get(Opc, LC, RC, Flags));
+    }
     return nullptr;
   }
 
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getSRem(LC, RC));
-    return nullptr;
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return FoldBinOp(Opc, LHS, RHS);
   }
-
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
@@ -181,54 +164,6 @@ class TargetFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
-  //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Constant *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFAdd(LHS, RHS));
-  }
-  Constant *CreateSub(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getSub(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFSub(LHS, RHS));
-  }
-  Constant *CreateMul(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getMul(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFMul(LHS, RHS));
-  }
-  Constant *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFDiv(LHS, RHS));
-  }
-  Constant *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFRem(LHS, RHS));
-  }
-  Constant *CreateShl(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getShl(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateLShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return Fold(ConstantExpr::getLShr(LHS, RHS, isExact));
-  }
-  Constant *CreateAShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return Fold(ConstantExpr::getAShr(LHS, RHS, isExact));
-  }
-  Constant *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getXor(LHS, RHS));
-  }
-
-  Constant *CreateBinOp(Instruction::BinaryOps Opc,
-                        Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::get(Opc, LHS, RHS));
-  }
-
   //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/ConstantFolder.h b/llvm/include/llvm/IR/ConstantFolder.h
index 6f5661d043a2d..9cf68dc39a652 100644
--- a/llvm/include/llvm/IR/ConstantFolder.h
+++ b/llvm/include/llvm/IR/ConstantFolder.h
@@ -22,6 +22,7 @@
 #include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/IRBuilderFolder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Operator.h"
 
 namespace llvm {
 
@@ -38,61 +39,44 @@ class ConstantFolder final : public IRBuilderFolder {
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getAdd(LC, RC, HasNUW, HasNSW);
-    return nullptr;
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getAnd(LC, RC);
-    return nullptr;
-  }
-
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getOr(LC, RC);
-    return nullptr;
-  }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return ConstantExpr::getUDiv(LC, RC, IsExact);
+      return ConstantExpr::get(Opc, LC, RC);
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return ConstantExpr::getSDiv(LC, RC, IsExact);
+      return ConstantExpr::get(Opc, LC, RC,
+                               IsExact ? PossiblyExactOperator::IsExact : 0);
     return nullptr;
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getURem(LC, RC);
+    if (LC && RC) {
+      unsigned Flags = 0;
+      if (HasNUW)
+        Flags |= OverflowingBinaryOperator::NoUnsignedWrap;
+      if (HasNSW)
+        Flags |= OverflowingBinaryOperator::NoSignedWrap;
+      return ConstantExpr::get(Opc, LC, RC, Flags);
+    }
     return nullptr;
   }
 
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getSRem(LC, RC);
-    return nullptr;
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return FoldBinOp(Opc, LHS, RHS);
   }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
@@ -170,68 +154,6 @@ class ConstantFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
-  //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Constant *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFAdd(LHS, RHS);
-  }
-
-  Constant *CreateSub(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getSub(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFSub(LHS, RHS);
-  }
-
-  Constant *CreateMul(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getMul(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFMul(LHS, RHS);
-  }
-
-  Constant *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFDiv(LHS, RHS);
-  }
-
-  Constant *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFRem(LHS, RHS);
-  }
-
-  Constant *CreateShl(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getShl(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateLShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return ConstantExpr::getLShr(LHS, RHS, isExact);
-  }
-
-  Constant *CreateAShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return ConstantExpr::getAShr(LHS, RHS, isExact);
-  }
-
-  Constant *CreateOr(Constant *LHS, Constant *RHS) const {
-    return ConstantExpr::getOr(LHS, RHS);
-  }
-
-  Constant *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getXor(LHS, RHS);
-  }
-
-  Constant *CreateBinOp(Instruction::BinaryOps Opc,
-                        Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::get(Opc, LHS, RHS);
-  }
-
   //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 3267d999d6ff8..902d945baabe2 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1158,13 +1158,6 @@ class IRBuilderBase {
     return I;
   }
 
-  Value *foldConstant(Instruction::BinaryOps Opc, Value *L,
-                      Value *R, const Twine &Name) const {
-    auto *LC = dyn_cast<Constant>(L);
-    auto *RC = dyn_cast<Constant>(R);
-    return (LC && RC) ? Insert(Folder.CreateBinOp(Opc, LC, RC), Name) : nullptr;
-  }
-
   Value *getConstrainedFPRounding(Optional<RoundingMode> Rounding) {
     RoundingMode UseRounding = DefaultConstrainedRounding;
 
@@ -1206,10 +1199,11 @@ class IRBuilderBase {
 public:
   Value *CreateAdd(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *V = Folder.FoldAdd(LHS, RHS, HasNUW, HasNSW))
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Add, LHS, RHS, HasNUW, HasNSW))
       return V;
-    return CreateInsertNUWNSWBinOp(Instruction::Add, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    return CreateInsertNUWNSWBinOp(Instruction::Add, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWAdd(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1222,11 +1216,11 @@ class IRBuilderBase {
 
   Value *CreateSub(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateSub(LC, RC, HasNUW, HasNSW), Name);
-    return CreateInsertNUWNSWBinOp(Instruction::Sub, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Sub, LHS, RHS, HasNUW, HasNSW))
+      return V;
+    return CreateInsertNUWNSWBinOp(Instruction::Sub, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWSub(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1239,11 +1233,11 @@ class IRBuilderBase {
 
   Value *CreateMul(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateMul(LC, RC, HasNUW, HasNSW), Name);
-    return CreateInsertNUWNSWBinOp(Instruction::Mul, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Mul, LHS, RHS, HasNUW, HasNSW))
+      return V;
+    return CreateInsertNUWNSWBinOp(Instruction::Mul, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWMul(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1256,7 +1250,7 @@ class IRBuilderBase {
 
   Value *CreateUDiv(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (Value *V = Folder.FoldUDiv(LHS, RHS, isExact))
+    if (Value *V = Folder.FoldExactBinOp(Instruction::UDiv, LHS, RHS, isExact))
       return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateUDiv(LHS, RHS), Name);
@@ -1269,7 +1263,7 @@ class IRBuilderBase {
 
   Value *CreateSDiv(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (Value *V = Folder.FoldSDiv(LHS, RHS, isExact))
+    if (Value *V = Folder.FoldExactBinOp(Instruction::SDiv, LHS, RHS, isExact))
       return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateSDiv(LHS, RHS), Name);
@@ -1281,22 +1275,22 @@ class IRBuilderBase {
   }
 
   Value *CreateURem(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = Folder.FoldURem(LHS, RHS))
+    if (Value *V = Folder.FoldBinOp(Instruction::URem, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateURem(LHS, RHS), Name);
   }
 
   Value *CreateSRem(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = Folder.FoldSRem(LHS, RHS))
+    if (Value *V = Folder.FoldBinOp(Instruction::SRem, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateSRem(LHS, RHS), Name);
   }
 
   Value *CreateShl(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateShl(LC, RC, HasNUW, HasNSW), Name);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Shl, LHS, RHS, HasNUW, HasNSW))
+      return V;
     return CreateInsertNUWNSWBinOp(Instruction::Shl, LHS, RHS, Name,
                                    HasNUW, HasNSW);
   }
@@ -1315,9 +1309,8 @@ class IRBuilderBase {
 
   Value *CreateLShr(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateLShr(LC, RC, isExact), Name);
+    if (Value *V = Folder.FoldExactBinOp(Instruction::LShr, LHS, RHS, isExact))
+      return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateLShr(LHS, RHS), Name);
     return Insert(BinaryOperator::CreateExactLShr(LHS, RHS), Name);
@@ -1335,9 +1328,8 @@ class IRBuilderBase {
 
   Value *CreateAShr(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateAShr(LC, RC, isExact), Name);
+    if (Value *V = Folder.FoldExactBinOp(Instruction::AShr, LHS, RHS, isExact))
+      return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateAShr(LHS, RHS), Name);
     return Insert(BinaryOperator::CreateExactAShr(LHS, RHS), Name);
@@ -1354,7 +1346,7 @@ class IRBuilderBase {
   }
 
   Value *CreateAnd(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (auto *V = Folder.FoldAnd(LHS, RHS))
+    if (auto *V = Folder.FoldBinOp(Instruction::And, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateAnd(LHS, RHS), Name);
   }
@@ -1376,7 +1368,7 @@ class IRBuilderBase {
   }
 
   Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (auto *V = Folder.FoldOr(LHS, RHS))
+    if (auto *V = Folder.FoldBinOp(Instruction::Or, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateOr(LHS, RHS), Name);
   }
@@ -1398,7 +1390,8 @@ class IRBuilderBase {
   }
 
   Value *CreateXor(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = foldConstant(Instruction::Xor, LHS, RHS, Name)) return V;
+    if (Value *V = Folder.FoldBinOp(Instruction::Xor, LHS, RHS))
+      return V;
     return Insert(BinaryOperator::CreateXor(LHS, RHS), Name);
   }
 
@@ -1416,7 +1409,8 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1429,9 +1423,10 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1441,7 +1436,8 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1454,9 +1450,10 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1466,7 +1463,8 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1479,9 +1477,10 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1491,7 +1490,8 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1504,9 +1504,10 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1516,7 +1517,7 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1529,16 +1530,16 @@ class IRBuilderBase {
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
   Value *CreateBinOp(Instruction::BinaryOps Opc,
                      Value *LHS, Value *RHS, const Twine &Name = "",
                      MDNode *FPMathTag = nullptr) {
-    if (Value *V = foldConstant(Opc, LHS, RHS, Name)) return V;
+    if (Value *V = Folder.FoldBinOp(Opc, LHS, RHS)) return V;
     Instruction *BinOp = BinaryOperator::Create(Opc, LHS, RHS);
     if (isa<FPMathOperator>(BinOp))
       setFPAttrs(BinOp, FPMathTag, FMF);

diff  --git a/llvm/include/llvm/IR/IRBuilderFolder.h b/llvm/include/llvm/IR/IRBuilderFolder.h
index 38e150e1d9555..1cc59938b2d91 100644
--- a/llvm/include/llvm/IR/IRBuilderFolder.h
+++ b/llvm/include/llvm/IR/IRBuilderFolder.h
@@ -31,20 +31,19 @@ class IRBuilderFolder {
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  virtual Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                         bool HasNSW = false) const = 0;
 
-  virtual Value *FoldAnd(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                           Value *RHS) const = 0;
 
-  virtual Value *FoldOr(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                                Value *RHS, bool IsExact) const = 0;
 
-  virtual Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const = 0;
+  virtual Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                                 Value *RHS, bool HasNUW,
+                                 bool HasNSW) const = 0;
 
-  virtual Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const = 0;
-
-  virtual Value *FoldURem(Value *LHS, Value *RHS) const = 0;
-
-  virtual Value *FoldSRem(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS,
+                              Value *RHS, FastMathFlags FMF) const = 0;
 
   virtual Value *FoldICmp(CmpInst::Predicate P, Value *LHS,
                           Value *RHS) const = 0;
@@ -68,29 +67,6 @@ class IRBuilderFolder {
   virtual Value *FoldShuffleVector(Value *V1, Value *V2,
                                    ArrayRef<int> Mask) const = 0;
 
-  //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  virtual Value *CreateFAdd(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateSub(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateFSub(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateMul(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateFMul(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateFDiv(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateFRem(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateShl(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateLShr(Constant *LHS, Constant *RHS,
-                            bool isExact = false) const = 0;
-  virtual Value *CreateAShr(Constant *LHS, Constant *RHS,
-                            bool isExact = false) const = 0;
-  virtual Value *CreateXor(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateBinOp(Instruction::BinaryOps Opc,
-                             Constant *LHS, Constant *RHS) const = 0;
-
   //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/NoFolder.h b/llvm/include/llvm/IR/NoFolder.h
index 4b9e1834bf5da..183b5a466c597 100644
--- a/llvm/include/llvm/IR/NoFolder.h
+++ b/llvm/include/llvm/IR/NoFolder.h
@@ -43,26 +43,26 @@ class NoFolder final : public IRBuilderFolder {
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
+
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     return nullptr;
   }
 
-  Value *FoldAnd(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldOr(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     return nullptr;
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldSRem(Value *LHS, Value *RHS) const override { return nullptr; }
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return nullptr;
+  }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
     return nullptr;
@@ -101,79 +101,6 @@ class NoFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
-  //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Instruction *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFAdd(LHS, RHS);
-  }
-
-  Instruction *CreateSub(Constant *LHS, Constant *RHS,
-                         bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateSub(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFSub(LHS, RHS);
-  }
-
-  Instruction *CreateMul(Constant *LHS, Constant *RHS,
-                         bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateMul(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFMul(LHS, RHS);
-  }
-
-  Instruction *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFDiv(LHS, RHS);
-  }
-
-  Instruction *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFRem(LHS, RHS);
-  }
-
-  Instruction *CreateShl(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateShl(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateLShr(Constant *LHS, Constant *RHS,
-                          bool isExact = false) const override {
-    if (!isExact)
-      return BinaryOperator::CreateLShr(LHS, RHS);
-    return BinaryOperator::CreateExactLShr(LHS, RHS);
-  }
-
-  Instruction *CreateAShr(Constant *LHS, Constant *RHS,
-                          bool isExact = false) const override {
-    if (!isExact)
-      return BinaryOperator::CreateAShr(LHS, RHS);
-    return BinaryOperator::CreateExactAShr(LHS, RHS);
-  }
-
-  Instruction *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateXor(LHS, RHS);
-  }
-
-  Instruction *CreateBinOp(Instruction::BinaryOps Opc,
-                           Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::Create(Opc, LHS, RHS);
-  }
-
   //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index a02eec1c12ad6..401f1ee5a55d5 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -273,7 +273,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
   }
 
   // If we haven't found this binop, insert it.
-  Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS));
+  // TODO: Use the Builder, which will make CreateBinOp below fold with
+  // InstSimplifyFolder.
+  Instruction *BO = Builder.Insert(BinaryOperator::Create(Opcode, LHS, RHS));
   BO->setDebugLoc(Loc);
   if (Flags & SCEV::FlagNUW)
     BO->setHasNoUnsignedWrap();


        


More information about the llvm-commits mailing list