[llvm] 1378e7d - [InstSimplify] add no-wrap parameters to simplifyMul and add more tests; NFC

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 18 10:29:58 PST 2023


Author: Sanjay Patel
Date: 2023-01-18T13:29:30-05:00
New Revision: 1378e7d8b8f3c536a0ad218b1f7a0a6cf963fbcf

URL: https://github.com/llvm/llvm-project/commit/1378e7d8b8f3c536a0ad218b1f7a0a6cf963fbcf
DIFF: https://github.com/llvm/llvm-project/commit/1378e7d8b8f3c536a0ad218b1f7a0a6cf963fbcf.diff

LOG: [InstSimplify] add no-wrap parameters to simplifyMul and add more tests; NFC

This gives simplifyMulInst the same no-wrap (nsw/nuw) flag parameters
that simplifyAddInst and simplifySubInst already take.
A potential improvement that uses nsw was noted in commit:
1720ec6da040729f17

Added: 
    (none)

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstSimplify/mul.ll

Removed: 
    (none)


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 0a2f199794f8b..d75e041567949 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -143,17 +143,36 @@ struct SimplifyQuery {
 // deprecated.
 // Please use the SimplifyQuery versions in new code.
 
-/// Given operand for an FNeg, fold the result or return null.
-Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q);
-
 /// Given operands for an Add, fold the result or return null.
-Value *simplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW,
+Value *simplifyAddInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW,
                        const SimplifyQuery &Q);
 
 /// Given operands for a Sub, fold the result or return null.
-Value *simplifySubInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW,
+Value *simplifySubInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW,
+                       const SimplifyQuery &Q);
+
+/// Given operands for a Mul, fold the result or return null.
+Value *simplifyMulInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW,
                        const SimplifyQuery &Q);
 
+/// Given operands for an SDiv, fold the result or return null.
+Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact,
+                        const SimplifyQuery &Q);
+
+/// Given operands for a UDiv, fold the result or return null.
+Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact,
+                        const SimplifyQuery &Q);
+
+/// Given operands for an SRem, fold the result or return null.
+Value *simplifySRemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
+
+/// Given operands for a URem, fold the result or return null.
+Value *simplifyURemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
+
+/// Given operand for an FNeg, fold the result or return null.
+Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q);
+
+
 /// Given operands for an FAdd, fold the result or return null.
 Value *
 simplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF,
@@ -184,17 +203,6 @@ Value *simplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF,
                        fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
                        RoundingMode Rounding = RoundingMode::NearestTiesToEven);
 
-/// Given operands for a Mul, fold the result or return null.
-Value *simplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
-
-/// Given operands for an SDiv, fold the result or return null.
-Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact,
-                        const SimplifyQuery &Q);
-
-/// Given operands for a UDiv, fold the result or return null.
-Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact,
-                        const SimplifyQuery &Q);
-
 /// Given operands for an FDiv, fold the result or return null.
 Value *
 simplifyFDivInst(Value *LHS, Value *RHS, FastMathFlags FMF,
@@ -202,12 +210,6 @@ simplifyFDivInst(Value *LHS, Value *RHS, FastMathFlags FMF,
                  fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
                  RoundingMode Rounding = RoundingMode::NearestTiesToEven);
 
-/// Given operands for an SRem, fold the result or return null.
-Value *simplifySRemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
-
-/// Given operands for a URem, fold the result or return null.
-Value *simplifyURemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
-
 /// Given operands for an FRem, fold the result or return null.
 Value *
 simplifyFRemInst(Value *LHS, Value *RHS, FastMathFlags FMF,
@@ -216,15 +218,15 @@ simplifyFRemInst(Value *LHS, Value *RHS, FastMathFlags FMF,
                  RoundingMode Rounding = RoundingMode::NearestTiesToEven);
 
 /// Given operands for a Shl, fold the result or return null.
-Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+Value *simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
                        const SimplifyQuery &Q);
 
 /// Given operands for a LShr, fold the result or return null.
-Value *simplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
                         const SimplifyQuery &Q);
 
 /// Given operands for a AShr, fold the result or return nulll.
-Value *simplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
+Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
                         const SimplifyQuery &Q);
 
 /// Given operands for an And, fold the result or return null.

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 5f6548b9cd59c..53434c4cb2ac3 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -922,8 +922,8 @@ Value *llvm::simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
 
 /// Given operands for a Mul, see if we can fold the result.
 /// If not, this returns null.
-static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
-                              unsigned MaxRecurse) {
+static Value *simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
     return C;
 
@@ -980,8 +980,9 @@ static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
   return nullptr;
 }
 
-Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
-  return ::simplifyMulInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+                             const SimplifyQuery &Q) {
+  return ::simplifyMulInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit);
 }
 
 /// Check for common or similar folds of integer division or integer remainder.
@@ -5707,7 +5708,8 @@ static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
     return simplifySubInst(LHS, RHS,  /* IsNSW */ false, /* IsNUW */ false, Q,
                            MaxRecurse);
   case Instruction::Mul:
-    return simplifyMulInst(LHS, RHS, Q, MaxRecurse);
+    return simplifyMulInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q,
+                           MaxRecurse);
   case Instruction::SDiv:
     return simplifySDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
   case Instruction::UDiv:
@@ -6582,7 +6584,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
   case Instruction::FMul:
     return simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
   case Instruction::Mul:
-    return simplifyMulInst(NewOps[0], NewOps[1], Q);
+    return simplifyMulInst(NewOps[0], NewOps[1],
+                           Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+                           Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
   case Instruction::SDiv:
     return simplifySDivInst(NewOps[0], NewOps[1],
                             Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 460731d29c8b7..97f129e200de7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -187,7 +187,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
 
 Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  if (Value *V = simplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I)))
+  if (Value *V =
+          simplifyMulInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+                          SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (SimplifyAssociativeOrCommutative(I))

diff  --git a/llvm/test/Transforms/InstSimplify/mul.ll b/llvm/test/Transforms/InstSimplify/mul.ll
index 902bde54841f6..443a2250b0a20 100644
--- a/llvm/test/Transforms/InstSimplify/mul.ll
+++ b/llvm/test/Transforms/InstSimplify/mul.ll
@@ -50,12 +50,39 @@ define i32 @poison(i32 %x) {
   ret i32 %v
 }
 
+define i1 @mul_i1(i1 %x, i1 %y) {
+; CHECK-LABEL: @mul_i1(
+; CHECK-NEXT:    [[R:%.*]] = mul i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %r = mul i1 %x, %y
+  ret i1 %r
+}
+
+define i1 @mul_i1_nsw(i1 %x, i1 %y) {
+; CHECK-LABEL: @mul_i1_nsw(
+; CHECK-NEXT:    [[R:%.*]] = mul nsw i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %r = mul nsw i1 %x, %y
+  ret i1 %r
+}
+
+define i1 @mul_i1_nuw(i1 %x, i1 %y) {
+; CHECK-LABEL: @mul_i1_nuw(
+; CHECK-NEXT:    [[R:%.*]] = mul nuw i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %r = mul nuw i1 %x, %y
+  ret i1 %r
+}
+
 define i1 @square_i1(i1 %x) {
 ; CHECK-LABEL: @square_i1(
 ; CHECK-NEXT:    ret i1 [[X:%.*]]
 ;
   %r = mul i1 %x, %x
-  ret i1 %x
+  ret i1 %r
 }
 
 define i1 @square_i1_nsw(i1 %x) {
@@ -63,7 +90,7 @@ define i1 @square_i1_nsw(i1 %x) {
 ; CHECK-NEXT:    ret i1 [[X:%.*]]
 ;
   %r = mul nsw i1 %x, %x
-  ret i1 %x
+  ret i1 %r
 }
 
 define i1 @square_i1_nuw(i1 %x) {
@@ -71,5 +98,5 @@ define i1 @square_i1_nuw(i1 %x) {
 ; CHECK-NEXT:    ret i1 [[X:%.*]]
 ;
   %r = mul nuw i1 %x, %x
-  ret i1 %x
+  ret i1 %r
 }


        


More information about the llvm-commits mailing list