[llvm] 6cd5eb1 - [InstCombine] Avoid some uses of ConstantExpr::getZExt() (NFC)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 28 08:02:41 PDT 2023


Author: Nikita Popov
Date: 2023-09-28T17:02:33+02:00
New Revision: 6cd5eb1f5469e3148c7504e0183b1b84227b495e

URL: https://github.com/llvm/llvm-project/commit/6cd5eb1f5469e3148c7504e0183b1b84227b495e
DIFF: https://github.com/llvm/llvm-project/commit/6cd5eb1f5469e3148c7504e0183b1b84227b495e.diff

LOG: [InstCombine] Avoid some uses of ConstantExpr::getZExt() (NFC)

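Add helpers getLosslessUnsignedTrunc/getLosslessSignedTrunc for the
common pattern of truncating a constant and checking that extending it
back to the original type yields the original constant.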

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index c04fe827207fd0d..703065280096fcb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1669,7 +1669,7 @@ bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) {
 
 /// Fold {and,or,xor} (cast X), C.
 static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
-                                          InstCombiner::BuilderTy &Builder) {
+                                          InstCombinerImpl &IC) {
   Constant *C = dyn_cast<Constant>(Logic.getOperand(1));
   if (!C)
     return nullptr;
@@ -1684,21 +1684,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
   // instruction may be cheaper (particularly in the case of vectors).
   Value *X;
   if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
-    Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
-    Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy);
-    if (ZextTruncC == C) {
+    if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
       // LogicOpc (zext X), C --> zext (LogicOpc X, C)
-      Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+      Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
       return new ZExtInst(NewOp, DestTy);
     }
   }
 
   if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) {
-    Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
-    Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy);
-    if (SextTruncC == C) {
+    if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
       // LogicOpc (sext X), C --> sext (LogicOpc X, C)
-      Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+      Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
       return new SExtInst(NewOp, DestTy);
     }
   }
@@ -1756,7 +1752,7 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
   if (!SrcTy->isIntOrIntVectorTy())
     return nullptr;
 
-  if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder))
+  if (Instruction *Ret = foldLogicCastConstant(I, Cast0, *this))
     return Ret;
 
   CastInst *Cast1 = dyn_cast<CastInst>(Op1);

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c6100f24b0507de..e29fb869686ca0b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1600,8 +1600,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Constant *C;
     if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
         I0->hasOneUse()) {
-      Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
-      if (ConstantExpr::getZExt(NarrowC, II->getType()) == C) {
+      if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
         Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
         return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
       }
@@ -1623,8 +1622,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Constant *C;
     if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
         I0->hasOneUse()) {
-      Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
-      if (ConstantExpr::getSExt(NarrowC, II->getType()) == C) {
+      if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
         Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
         return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
       }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 40c24d87bfec508..d5f76073fc3ccb3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -219,6 +219,24 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
                         const Instruction *CtxI) const;
 
+  Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
+    Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy); // Narrow C.
+    Constant *ExtTruncC = // Zero-extend the narrowed value to C's type.
+        ConstantFoldCastOperand(Instruction::ZExt, TruncC, C->getType(), DL);
+    if (ExtTruncC && ExtTruncC == C) // Round trip gave back C itself.
+      return TruncC; // Callers use the narrowed constant in place of C.
+    return nullptr; // No fold result, or the round trip did not yield C.
+  }
+
+  Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
+    Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy); // Narrow C.
+    Constant *ExtTruncC = // Sign-extend the narrowed value to C's type.
+        ConstantFoldCastOperand(Instruction::SExt, TruncC, C->getType(), DL);
+    if (ExtTruncC && ExtTruncC == C) // Round trip gave back C itself.
+      return TruncC; // Callers use the narrowed constant in place of C.
+    return nullptr; // No fold result, or the round trip did not yield C.
+  }
+
 private:
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
   bool isDesirableIntType(unsigned BitWidth) const;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index dc091ec7c60e8dd..0db5fc254f3c06f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1268,7 +1268,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
 /// If we have zero-extended operands of an unsigned div or rem, we may be able
 /// to narrow the operation (sink the zext below the math).
 static Instruction *narrowUDivURem(BinaryOperator &I,
-                                   InstCombiner::BuilderTy &Builder) {
+                                   InstCombinerImpl &IC) {
   Instruction::BinaryOps Opcode = I.getOpcode();
   Value *N = I.getOperand(0);
   Value *D = I.getOperand(1);
@@ -1278,7 +1278,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
       X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) {
     // udiv (zext X), (zext Y) --> zext (udiv X, Y)
     // urem (zext X), (zext Y) --> zext (urem X, Y)
-    Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y);
+    Value *NarrowOp = IC.Builder.CreateBinOp(Opcode, X, Y);
     return new ZExtInst(NarrowOp, Ty);
   }
 
@@ -1286,24 +1286,24 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
   if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
       match(D, m_Constant(C))) {
     // If the constant is the same in the smaller type, use the narrow version.
-    Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
-    if (ConstantExpr::getZExt(TruncC, Ty) != C)
+    Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+    if (!TruncC)
       return nullptr;
 
     // udiv (zext X), C --> zext (udiv X, C')
     // urem (zext X), C --> zext (urem X, C')
-    return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty);
+    return new ZExtInst(IC.Builder.CreateBinOp(Opcode, X, TruncC), Ty);
   }
   if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
       match(N, m_Constant(C))) {
     // If the constant is the same in the smaller type, use the narrow version.
-    Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
-    if (ConstantExpr::getZExt(TruncC, Ty) != C)
+    Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+    if (!TruncC)
       return nullptr;
 
     // udiv C, (zext X) --> zext (udiv C', X)
     // urem C, (zext X) --> zext (urem C', X)
-    return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty);
+    return new ZExtInst(IC.Builder.CreateBinOp(Opcode, TruncC, X), Ty);
   }
 
   return nullptr;
@@ -1351,7 +1351,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
     return CastInst::CreateZExtOrBitCast(Cmp, Ty);
   }
 
-  if (Instruction *NarrowDiv = narrowUDivURem(I, Builder))
+  if (Instruction *NarrowDiv = narrowUDivURem(I, *this))
     return NarrowDiv;
 
   // If the udiv operands are non-overflowing multiplies with a common operand,
@@ -1941,7 +1941,7 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
   if (Instruction *common = commonIRemTransforms(I))
     return common;
 
-  if (Instruction *NarrowRem = narrowUDivURem(I, Builder))
+  if (Instruction *NarrowRem = narrowUDivURem(I, *this))
     return NarrowRem;
 
   // X urem Y -> X and Y-1, where Y is a power of 2,

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 00115abf6500597..5b8b74c8648b8c0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -825,8 +825,8 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
       NumZexts++;
     } else if (auto *C = dyn_cast<Constant>(V)) {
       // Make sure that constants can fit in the new type.
-      Constant *Trunc = ConstantExpr::getTrunc(C, NarrowType);
-      if (ConstantExpr::getZExt(Trunc, C->getType()) != C)
+      Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType);
+      if (!Trunc)
         return nullptr;
       NewIncoming.push_back(Trunc);
       NumConsts++;


        


More information about the llvm-commits mailing list