[llvm] 6e379de - [Reassociate] Preserve `nuw` and `nsw` on `mul` chains

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 1 07:23:02 PDT 2024


Author: Noah Goldstein
Date: 2024-07-01T22:22:36+08:00
New Revision: 6e379de3b144363c2f5a6f9335eef6f42e28ef37

URL: https://github.com/llvm/llvm-project/commit/6e379de3b144363c2f5a6f9335eef6f42e28ef37
DIFF: https://github.com/llvm/llvm-project/commit/6e379de3b144363c2f5a6f9335eef6f42e28ef37.diff

LOG: [Reassociate] Preserve `nuw` and `nsw` on `mul` chains

Basically the same rules as `add`, but we also need to ensure all
operands are non-zero.

Proofs: https://alive2.llvm.org/ce/z/jzsYht

Closes #97040

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Scalar/Reassociate.h
    llvm/lib/Transforms/Scalar/Reassociate.cpp
    llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 8c93afcd587e8..3b2d2b83ced62 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -68,10 +68,13 @@ struct OverflowTracking {
   bool HasNUW;
   bool HasNSW;
   bool AllKnownNonNegative;
+  bool AllKnownNonZero;
   // Note: AllKnownNonNegative can be true in a case where one of the operands
   // is negative, but one the operators is not NSW. AllKnownNonNegative should
   // not be used independently of HasNSW
-  OverflowTracking() : HasNUW(true), HasNSW(true), AllKnownNonNegative(true) {}
+  OverflowTracking()
+      : HasNUW(true), HasNSW(true), AllKnownNonNegative(true),
+        AllKnownNonZero(true) {}
 };
 
 class XorOpnd;

diff  --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index ce7b95af24291..cc01d2e59a87e 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -539,6 +539,16 @@ static bool LinearizeExprTree(Instruction *I,
     Ops.push_back(std::make_pair(V, Weight));
     if (Opcode == Instruction::Add && Flags.AllKnownNonNegative && Flags.HasNSW)
       Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
+    else if (Opcode == Instruction::Mul) {
+      // To preserve NUW we need all inputs non-zero.
+      // To preserve NSW we need all inputs strictly positive.
+      if (Flags.AllKnownNonZero &&
+          (Flags.HasNUW || (Flags.HasNSW && Flags.AllKnownNonNegative))) {
+        Flags.AllKnownNonZero &= isKnownNonZero(V, SimplifyQuery(DL));
+        if (Flags.HasNSW && Flags.AllKnownNonNegative)
+          Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
+      }
+    }
   }
 
   // For nilpotent operations or addition there may be no operands, for example
@@ -722,10 +732,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
           ExpressionChangedStart->setFastMathFlags(Flags);
         } else {
           ExpressionChangedStart->clearSubclassOptionalData();
-          // Note that it doesn't hold for mul if one of the operands is zero.
-          // TODO: We can preserve NUW flag if we prove that all mul operands
-          // are non-zero.
-          if (ExpressionChangedStart->getOpcode() == Instruction::Add) {
+          if (ExpressionChangedStart->getOpcode() == Instruction::Add ||
+              (ExpressionChangedStart->getOpcode() == Instruction::Mul &&
+               Flags.AllKnownNonZero)) {
             if (Flags.HasNUW)
               ExpressionChangedStart->setHasNoUnsignedWrap();
             if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))

diff  --git a/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll
index 79cd57c292ce6..3ee6ffe2f22b8 100644
--- a/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll
+++ b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll
@@ -21,8 +21,8 @@ define i4 @nuw_preserve_non_zero(i4 %a, i4 %b, i4 %c) {
 ; CHECK-NEXT:    [[A0:%.*]] = add nuw i4 [[A]], 1
 ; CHECK-NEXT:    [[B0:%.*]] = add nuw i4 [[B]], 1
 ; CHECK-NEXT:    [[C0:%.*]] = add nuw i4 [[C]], 1
-; CHECK-NEXT:    [[V0:%.*]] = mul i4 [[B0]], [[A0]]
-; CHECK-NEXT:    [[V1:%.*]] = mul i4 [[V0]], [[C0]]
+; CHECK-NEXT:    [[V0:%.*]] = mul nuw i4 [[B0]], [[A0]]
+; CHECK-NEXT:    [[V1:%.*]] = mul nuw i4 [[V0]], [[C0]]
 ; CHECK-NEXT:    ret i4 [[V1]]
 ;
   %a0 = add nuw i4 %a, 1
@@ -40,9 +40,9 @@ define i4 @re_order_mul_nuw(i4 %xx0, i4 %xx1, i4 %xx2, i4 %xx3) {
 ; CHECK-NEXT:    [[X1:%.*]] = add nuw i4 [[XX1]], 1
 ; CHECK-NEXT:    [[X2:%.*]] = add nuw i4 [[XX2]], 1
 ; CHECK-NEXT:    [[X3:%.*]] = add nuw i4 [[XX3]], 1
-; CHECK-NEXT:    [[MUL_B:%.*]] = mul i4 [[X1]], [[X0]]
-; CHECK-NEXT:    [[MUL_A:%.*]] = mul i4 [[MUL_B]], [[X2]]
-; CHECK-NEXT:    [[MUL_C:%.*]] = mul i4 [[MUL_A]], [[X3]]
+; CHECK-NEXT:    [[MUL_B:%.*]] = mul nuw i4 [[X1]], [[X0]]
+; CHECK-NEXT:    [[MUL_A:%.*]] = mul nuw i4 [[MUL_B]], [[X2]]
+; CHECK-NEXT:    [[MUL_C:%.*]] = mul nuw i4 [[MUL_A]], [[X3]]
 ; CHECK-NEXT:    ret i4 [[MUL_C]]
 ;
   %x0 = add nuw i4 %xx0, 1
@@ -88,9 +88,9 @@ define i4 @re_order_mul_nsw(i4 %xx0, i4 %xx1, i4 %xx2, i4 %xx3) {
 ; CHECK-NEXT:    [[X1:%.*]] = call i4 @llvm.smax.i4(i4 [[X1_NZ]], i4 1)
 ; CHECK-NEXT:    [[X2:%.*]] = call i4 @llvm.smax.i4(i4 [[X2_NZ]], i4 1)
 ; CHECK-NEXT:    [[X3:%.*]] = call i4 @llvm.smax.i4(i4 [[X3_NZ]], i4 1)
-; CHECK-NEXT:    [[MUL_B:%.*]] = mul i4 [[X1]], [[X0]]
-; CHECK-NEXT:    [[MUL_A:%.*]] = mul i4 [[MUL_B]], [[X2]]
-; CHECK-NEXT:    [[MUL_C:%.*]] = mul i4 [[MUL_A]], [[X3]]
+; CHECK-NEXT:    [[MUL_B:%.*]] = mul nsw i4 [[X1]], [[X0]]
+; CHECK-NEXT:    [[MUL_A:%.*]] = mul nsw i4 [[MUL_B]], [[X2]]
+; CHECK-NEXT:    [[MUL_C:%.*]] = mul nsw i4 [[MUL_A]], [[X3]]
 ; CHECK-NEXT:    ret i4 [[MUL_C]]
 ;
   %x0_nz = add nuw i4 %xx0, 1
@@ -114,9 +114,9 @@ define i4 @re_order_mul_nsw_nuw(i4 %xx0, i4 %xx1, i4 %xx2, i4 %xx3) {
 ; CHECK-NEXT:    [[X1:%.*]] = add nuw i4 [[XX1]], 1
 ; CHECK-NEXT:    [[X2:%.*]] = add nuw i4 [[XX2]], 1
 ; CHECK-NEXT:    [[X3:%.*]] = add nuw i4 [[XX3]], 1
-; CHECK-NEXT:    [[MUL_B:%.*]] = mul i4 [[X1]], [[X0]]
-; CHECK-NEXT:    [[MUL_A:%.*]] = mul i4 [[MUL_B]], [[X2]]
-; CHECK-NEXT:    [[MUL_C:%.*]] = mul i4 [[MUL_A]], [[X3]]
+; CHECK-NEXT:    [[MUL_B:%.*]] = mul nuw nsw i4 [[X1]], [[X0]]
+; CHECK-NEXT:    [[MUL_A:%.*]] = mul nuw nsw i4 [[MUL_B]], [[X2]]
+; CHECK-NEXT:    [[MUL_C:%.*]] = mul nuw nsw i4 [[MUL_A]], [[X3]]
 ; CHECK-NEXT:    ret i4 [[MUL_C]]
 ;
   %x0 = add nuw i4 %xx0, 1


        


More information about the llvm-commits mailing list