[llvm] r299309 - [InstSimplify] add constant folding for fdiv/frem

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 1 12:05:11 PDT 2017


Author: spatel
Date: Sat Apr  1 14:05:11 2017
New Revision: 299309

URL: http://llvm.org/viewvc/llvm-project?rev=299309&view=rev
Log:
[InstSimplify] add constant folding for fdiv/frem

Also, add a helper function so we don't have to repeat this code for each binop.

Modified:
    llvm/trunk/lib/Analysis/InstructionSimplify.cpp
    llvm/trunk/test/Transforms/InstSimplify/fdiv.ll

Modified: llvm/trunk/lib/Analysis/InstructionSimplify.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/InstructionSimplify.cpp?rev=299309&r1=299308&r2=299309&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/InstructionSimplify.cpp (original)
+++ llvm/trunk/lib/Analysis/InstructionSimplify.cpp Sat Apr  1 14:05:11 2017
@@ -528,17 +528,26 @@ static Value *ThreadCmpOverPHI(CmpInst::
   return CommonValue;
 }
 
+/// Fold Op0 <Opcode> Op1 when both are constants. Otherwise return null, first
+/// swapping the operands if only Op0 is constant and Opcode is commutative.
+static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
+                                       Value *&Op0, Value *&Op1,
+                                       const Query &Q) {
+  auto *LHSConst = dyn_cast<Constant>(Op0);
+  auto *RHSConst = dyn_cast<Constant>(Op1);
+  if (LHSConst && RHSConst)
+    return ConstantFoldBinaryOpOperands(Opcode, LHSConst, RHSConst, Q.DL);
+  if (LHSConst && Instruction::isCommutative(Opcode))
+    std::swap(Op0, Op1);
+  return nullptr;
+}
+
 /// Given operands for an Add, see if we can fold the result.
 /// If not, this returns null.
 static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
+    return C;
 
   // X + undef -> undef
   if (match(Op1, m_Undef()))
@@ -674,9 +683,8 @@ static Constant *computePointerDifferenc
 /// If not, this returns null.
 static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0))
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
+    return C;
 
   // X - undef -> undef
   // undef - X -> undef
@@ -816,13 +824,8 @@ Value *llvm::SimplifySubInst(Value *Op0,
 /// returns null.
 static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
+    return C;
 
   // fadd X, -0 ==> X
   if (match(Op1, m_NegZero()))
@@ -855,10 +858,8 @@ static Value *SimplifyFAddInst(Value *Op
 /// returns null.
 static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
+    return C;
 
   // fsub X, 0 ==> X
   if (match(Op1, m_Zero()))
@@ -889,13 +890,8 @@ static Value *SimplifyFSubInst(Value *Op
 /// Given the operands for an FMul, see if we can fold the result
 static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+    return C;
 
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
@@ -912,13 +908,8 @@ static Value *SimplifyFMulInst(Value *Op
 /// If not, this returns null.
 static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
+    return C;
 
   // X * undef -> 0
   if (match(Op1, m_Undef()))
@@ -1060,9 +1051,8 @@ static Value *simplifyDivRem(Value *Op0,
 /// If not, this returns null.
 static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   if (Value *V = simplifyDivRem(Op0, Op1, true))
     return V;
@@ -1162,6 +1152,9 @@ Value *llvm::SimplifyUDivInst(Value *Op0
 
 static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
+    return C;
+
   // undef / X -> undef    (the undef could be a snan).
   if (match(Op0, m_Undef()))
     return Op0;
@@ -1211,9 +1204,8 @@ Value *llvm::SimplifyFDivInst(Value *Op0
 /// If not, this returns null.
 static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   if (Value *V = simplifyDivRem(Op0, Op1, false))
     return V;
@@ -1287,7 +1279,10 @@ Value *llvm::SimplifyURemInst(Value *Op0
 }
 
 static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
-                               const Query &, unsigned) {
+                               const Query &Q, unsigned) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
+    return C;
+
   // undef % X -> undef    (the undef could be a snan).
   if (match(Op0, m_Undef()))
     return Op0;
@@ -1343,11 +1338,10 @@ static bool isUndefShift(Value *Amount)
 
 /// Given operands for an Shl, LShr or AShr, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
-                            const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
+                            Value *Op1, const Query &Q, unsigned MaxRecurse) {
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   // 0 shift by X -> 0
   if (match(Op0, m_Zero()))
@@ -1394,8 +1388,8 @@ static Value *SimplifyShift(unsigned Opc
 
 /// \brief Given operands for an Shl, LShr or AShr, see if we can
 /// fold the result.  If not, this returns null.
-static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1,
-                                 bool isExact, const Query &Q,
+static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
+                                 Value *Op1, bool isExact, const Query &Q,
                                  unsigned MaxRecurse) {
   if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
     return V;
@@ -1644,13 +1638,8 @@ static Value *SimplifyAndOfICmps(ICmpIns
 /// If not, this returns null.
 static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q))
+    return C;
 
   // X & undef -> 0
   if (match(Op1, m_Undef()))
@@ -1846,13 +1835,8 @@ static Value *SimplifyOrOfICmps(ICmpInst
 /// If not, this returns null.
 static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q,
                              unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q))
+    return C;
 
   // X | undef -> -1
   if (match(Op1, m_Undef()))
@@ -1979,13 +1963,8 @@ Value *llvm::SimplifyOrInst(Value *Op0,
 /// If not, this returns null.
 static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q))
+    return C;
 
   // A ^ undef -> undef
   if (match(Op1, m_Undef()))

Modified: llvm/trunk/test/Transforms/InstSimplify/fdiv.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstSimplify/fdiv.ll?rev=299309&r1=299308&r2=299309&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstSimplify/fdiv.ll (original)
+++ llvm/trunk/test/Transforms/InstSimplify/fdiv.ll Sat Apr  1 14:05:11 2017
@@ -3,8 +3,7 @@
 
 define float @fdiv_constant_fold() {
 ; CHECK-LABEL: @fdiv_constant_fold(
-; CHECK-NEXT:    [[F:%.*]] = fdiv float 3.000000e+00, 2.000000e+00
-; CHECK-NEXT:    ret float [[F]]
+; CHECK-NEXT:    ret float 1.500000e+00
 ;
   %f = fdiv float 3.0, 2.0
   ret float %f
@@ -12,8 +11,7 @@ define float @fdiv_constant_fold() {
 
 define float @frem_constant_fold() {
 ; CHECK-LABEL: @frem_constant_fold(
-; CHECK-NEXT:    [[F:%.*]] = frem float 3.000000e+00, 2.000000e+00
-; CHECK-NEXT:    ret float [[F]]
+; CHECK-NEXT:    ret float 1.000000e+00
 ;
   %f = frem float 3.0, 2.0
   ret float %f




More information about the llvm-commits mailing list