[llvm] [InstCombine] Infer zext nneg flag directly (PR #71906)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 10 00:53:16 PST 2023


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/71906

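This PR sets the `nneg` flag directly on the `zext` instructions that InstCombine creates, at the points where non-negativity is already known. This salvages that information and avoids the `isKnownNonNegative` recomputation that was introduced by https://github.com/llvm/llvm-project/pull/71534.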

Alive2: https://alive2.llvm.org/ce/z/voT6HG

Compile-time impact: https://llvm-compile-time-tracker.com/compare.php?from=dc3faf0ed0e3f1ea9e435a006167d9649f865da1&to=04c38cca49af96b6a66fc006576d8482d1725648&stat=instructions:u

>From 04c38cca49af96b6a66fc006576d8482d1725648 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 10 Nov 2023 16:32:29 +0800
Subject: [PATCH] [InstCombine] Infer zext nneg flag directly

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp |  6 +++++-
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp |  4 +++-
 .../Transforms/InstCombine/InstCombineCompares.cpp   |  3 +++
 .../Transforms/InstCombine/InstCombineMulDivRem.cpp  |  7 ++++++-
 .../lib/Transforms/InstCombine/InstCombineShifts.cpp | 11 +++++++++--
 .../InstCombine/InstCombineSimplifyDemanded.cpp      |  2 ++
 llvm/test/Transforms/InstCombine/div-shift.ll        |  4 ++--
 llvm/test/Transforms/InstCombine/load-cmp.ll         |  2 +-
 llvm/test/Transforms/InstCombine/shift-by-signext.ll | 12 ++++++------
 llvm/test/Transforms/InstCombine/vector-udiv.ll      |  4 ++--
 10 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index b0d31a92464e375..bde39606886ba0a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -543,6 +543,8 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
       auto *Cttz = IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, X,
                                                     IC.Builder.getTrue());
       auto *ZextCttz = IC.Builder.CreateZExt(Cttz, II.getType());
+      if (auto *ZextInst = dyn_cast<PossiblyNonNegInst>(ZextCttz))
+        ZextInst->setNonNeg();
       return IC.replaceInstUsesWith(II, ZextCttz);
     }
 
@@ -639,7 +641,9 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
   // ctpop (zext X) --> zext (ctpop X)
   if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
     Value *NarrowPop = IC.Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, X);
-    return CastInst::Create(Instruction::ZExt, NarrowPop, Ty);
+    Instruction *Zext = CastInst::Create(Instruction::ZExt, NarrowPop, Ty);
+    Zext->setNonNeg();
+    return Zext;
   }
 
   KnownBits Known(BitWidth);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index c59cf5a2a86ef44..a52b3fcb35234ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1159,7 +1159,9 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
       APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize));
       Constant *AndConst = ConstantInt::get(A->getType(), AndValue);
       Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask");
-      return new ZExtInst(And, DestTy);
+      auto *ZExt = new ZExtInst(And, DestTy);
+      ZExt->setNonNeg();
+      return ZExt;
     }
 
     if (SrcSize == DstSize) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c2ad92f919a3cc..bf69e467f9b6fed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5914,6 +5914,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
         APInt ShortMask = CI->getValue().trunc(MulWidth);
         Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
         Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
+        if (ShortMask.isNonNegative())
+          if (auto *ZextInst = dyn_cast<PossiblyNonNegInst>(Zext))
+            ZextInst->setNonNeg();
         IC.replaceInstUsesWith(*BO, Zext);
       } else {
         llvm_unreachable("Unexpected Binary operation");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index db0804380855e3a..b471b2c887bd160 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1243,7 +1243,12 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   Value *X, *Y;
   if (match(Op, m_ZExt(m_Value(X))))
     if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
-      return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
+      return IfFold([&]() {
+        auto *V = Builder.CreateZExt(LogX, Op->getType());
+        if (auto *ZextInst = dyn_cast<PossiblyNonNegInst>(V))
+          ZextInst->setNonNeg();
+        return V;
+      });
 
   // log2(X << Y) -> log2(X) + Y
   // FIXME: Require one use unless X is 1?
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index e178f9536b69f21..ad3ce748dcd96c1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -409,6 +409,8 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   Value *Y;
   if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
     Value *NewExt = Builder.CreateZExt(Y, Ty, Op1->getName());
+    if (auto *ZextInst = dyn_cast<PossiblyNonNegInst>(NewExt))
+      ZextInst->setNonNeg();
     return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
   }
 
@@ -1331,7 +1333,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
              "Big shift not simplified to zero?");
       // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
       Value *NewLShr = Builder.CreateLShr(X, ShAmtC);
-      return new ZExtInst(NewLShr, Ty);
+      auto *Zext = new ZExtInst(NewLShr, Ty);
+      if (ShAmtC)
+        Zext->setNonNeg();
+      return Zext;
     }
 
     if (match(Op0, m_SExt(m_Value(X)))) {
@@ -1349,7 +1354,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
         // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
         if (ShAmtC == BitWidth - 1) {
           Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
-          return new ZExtInst(NewLShr, Ty);
+          auto *Zext = new ZExtInst(NewLShr, Ty);
+          Zext->setNonNeg();
+          return Zext;
         }
 
         // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 29cf04d82b2e401..248e3f8f7075cf3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -481,6 +481,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         DemandedMask.getActiveBits() <= SrcBitWidth) {
       // Convert to ZExt cast.
       CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
+      if (InputKnown.isNonNegative())
+        NewCast->setNonNeg();
       return InsertNewInstWith(NewCast, I->getIterator());
      }
 
diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index d208837f04594ab..9610746811a43a6 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -38,7 +38,7 @@ define <2 x i32> @t1vec(<2 x i16> %x, <2 x i32> %y) {
 ; rdar://11721329
 define i64 @t2(i64 %x, i32 %y) {
 ; CHECK-LABEL: @t2(
-; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[Y:%.*]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i32 [[Y:%.*]] to i64
 ; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i64 [[TMP2]]
 ;
@@ -52,7 +52,7 @@ define i64 @t2(i64 %x, i32 %y) {
 define i64 @t3(i64 %x, i32 %y) {
 ; CHECK-LABEL: @t3(
 ; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[Y:%.*]], 2
-; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = zext nneg i32 [[TMP1]] to i64
 ; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[X:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/load-cmp.ll b/llvm/test/Transforms/InstCombine/load-cmp.ll
index 56f6e042b3cafdb..e941284a798ed11 100644
--- a/llvm/test/Transforms/InstCombine/load-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/load-cmp.ll
@@ -122,7 +122,7 @@ define i1 @test4(i32 %X) {
 
 define i1 @test4_i16(i16 %X) {
 ; CHECK-LABEL: @test4_i16(
-; CHECK-NEXT:    [[TMP1:%.*]] = zext i16 [[X:%.*]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i16 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 933, [[TMP1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 1
 ; CHECK-NEXT:    [[R:%.*]] = icmp ne i32 [[TMP3]], 0
diff --git a/llvm/test/Transforms/InstCombine/shift-by-signext.ll b/llvm/test/Transforms/InstCombine/shift-by-signext.ll
index b72f33fc6502364..7fe4364cc08016d 100644
--- a/llvm/test/Transforms/InstCombine/shift-by-signext.ll
+++ b/llvm/test/Transforms/InstCombine/shift-by-signext.ll
@@ -6,7 +6,7 @@
 
 define i32 @t0_shl(i32 %x, i8 %shamt) {
 ; CHECK-LABEL: @t0_shl(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext i8 [[SHAMT:%.*]] to i32
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg i8 [[SHAMT:%.*]] to i32
 ; CHECK-NEXT:    [[R:%.*]] = shl i32 [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
@@ -16,7 +16,7 @@ define i32 @t0_shl(i32 %x, i8 %shamt) {
 }
 define i32 @t1_lshr(i32 %x, i8 %shamt) {
 ; CHECK-LABEL: @t1_lshr(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext i8 [[SHAMT:%.*]] to i32
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg i8 [[SHAMT:%.*]] to i32
 ; CHECK-NEXT:    [[R:%.*]] = lshr i32 [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
@@ -26,7 +26,7 @@ define i32 @t1_lshr(i32 %x, i8 %shamt) {
 }
 define i32 @t2_ashr(i32 %x, i8 %shamt) {
 ; CHECK-LABEL: @t2_ashr(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext i8 [[SHAMT:%.*]] to i32
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg i8 [[SHAMT:%.*]] to i32
 ; CHECK-NEXT:    [[R:%.*]] = ashr i32 [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
@@ -37,7 +37,7 @@ define i32 @t2_ashr(i32 %x, i8 %shamt) {
 
 define <2 x i32> @t3_vec_shl(<2 x i32> %x, <2 x i8> %shamt) {
 ; CHECK-LABEL: @t3_vec_shl(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext <2 x i8> [[SHAMT:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg <2 x i8> [[SHAMT:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[R:%.*]] = shl <2 x i32> [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret <2 x i32> [[R]]
 ;
@@ -47,7 +47,7 @@ define <2 x i32> @t3_vec_shl(<2 x i32> %x, <2 x i8> %shamt) {
 }
 define <2 x i32> @t4_vec_lshr(<2 x i32> %x, <2 x i8> %shamt) {
 ; CHECK-LABEL: @t4_vec_lshr(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext <2 x i8> [[SHAMT:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg <2 x i8> [[SHAMT:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[R:%.*]] = lshr <2 x i32> [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret <2 x i32> [[R]]
 ;
@@ -57,7 +57,7 @@ define <2 x i32> @t4_vec_lshr(<2 x i32> %x, <2 x i8> %shamt) {
 }
 define <2 x i32> @t5_vec_ashr(<2 x i32> %x, <2 x i8> %shamt) {
 ; CHECK-LABEL: @t5_vec_ashr(
-; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext <2 x i8> [[SHAMT:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[SHAMT_WIDE1:%.*]] = zext nneg <2 x i8> [[SHAMT:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[R:%.*]] = ashr <2 x i32> [[X:%.*]], [[SHAMT_WIDE1]]
 ; CHECK-NEXT:    ret <2 x i32> [[R]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/vector-udiv.ll b/llvm/test/Transforms/InstCombine/vector-udiv.ll
index c3468e95ba3d4fd..c817b3a1ac5a0a4 100644
--- a/llvm/test/Transforms/InstCombine/vector-udiv.ll
+++ b/llvm/test/Transforms/InstCombine/vector-udiv.ll
@@ -75,7 +75,7 @@ define <4 x i32> @test_v4i32_shl_const_pow2(<4 x i32> %a0, <4 x i32> %a1) {
 define <4 x i32> @test_v4i32_zext_shl_splatconst_pow2(<4 x i32> %a0, <4 x i16> %a1) {
 ; CHECK-LABEL: @test_v4i32_zext_shl_splatconst_pow2(
 ; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i16> [[A1:%.*]], <i16 2, i16 2, i16 2, i16 2>
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext nneg <4 x i16> [[TMP1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i32> [[A0:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
 ;
@@ -88,7 +88,7 @@ define <4 x i32> @test_v4i32_zext_shl_splatconst_pow2(<4 x i32> %a0, <4 x i16> %
 define <4 x i32> @test_v4i32_zext_shl_const_pow2(<4 x i32> %a0, <4 x i16> %a1) {
 ; CHECK-LABEL: @test_v4i32_zext_shl_const_pow2(
 ; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i16> [[A1:%.*]], <i16 2, i16 3, i16 4, i16 5>
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext nneg <4 x i16> [[TMP1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i32> [[A0:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
 ;



More information about the llvm-commits mailing list