[llvm] 6ce7461 - [InstCombine] Avoid uses of ConstantExpr::getCast()

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 29 02:32:49 PDT 2023


Author: Nikita Popov
Date: 2023-09-29T11:32:41+02:00
New Revision: 6ce7461eeab7a7ca7496518184a79361c5e05c5e

URL: https://github.com/llvm/llvm-project/commit/6ce7461eeab7a7ca7496518184a79361c5e05c5e
DIFF: https://github.com/llvm/llvm-project/commit/6ce7461eeab7a7ca7496518184a79361c5e05c5e.diff

LOG: [InstCombine] Avoid uses of ConstantExpr::getCast()

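Add a generalized getLosslessTrunc() helper, parameterized on the extension
opcode, to simplify the truncate-then-re-extend round-trip checks that
previously used ConstantExpr::getCast().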

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 05928b005dd47c5..091f7e1a6b74289 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5588,25 +5588,20 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
   if (!C)
     return nullptr;
 
-  // Compute the constant that would happen if we truncated to SrcTy then
-  // re-extended to DestTy.
+  // If a lossless truncate is possible...
   Type *SrcTy = CastOp0->getSrcTy();
-  Type *DestTy = CastOp0->getDestTy();
-  Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
-  Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
-
-  // If the re-extended constant didn't change...
-  if (Res2 == C) {
+  Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
+  if (Res) {
     if (ICmp.isEquality())
-      return new ICmpInst(ICmp.getPredicate(), X, Res1);
+      return new ICmpInst(ICmp.getPredicate(), X, Res);
 
     // A signed comparison of sign extended values simplifies into a
     // signed comparison.
     if (IsSignedExt && IsSignedCmp)
-      return new ICmpInst(ICmp.getPredicate(), X, Res1);
+      return new ICmpInst(ICmp.getPredicate(), X, Res);
 
     // The other three cases all fold into an unsigned comparison.
-    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
+    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res);
   }
 
   // The re-extended constant changed, partly changed (in the case of a vector),

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index d5f76073fc3ccb3..c90925fd6edd2f1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -219,22 +219,21 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
                         const Instruction *CtxI) const;
 
-  Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
+  Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) {
     Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
     Constant *ExtTruncC =
-        ConstantFoldCastOperand(Instruction::ZExt, TruncC, C->getType(), DL);
+        ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL);
     if (ExtTruncC && ExtTruncC == C)
       return TruncC;
     return nullptr;
   }
 
+  Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
+    return getLosslessTrunc(C, TruncTy, Instruction::ZExt);
+  }
+
   Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
-    Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
-    Constant *ExtTruncC =
-        ConstantFoldCastOperand(Instruction::SExt, TruncC, C->getType(), DL);
-    if (ExtTruncC && ExtTruncC == C)
-      return TruncC;
-    return nullptr;
+    return getLosslessTrunc(C, TruncTy, Instruction::SExt);
   }
 
 private:

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 05b3eaacc7b4d06..c42164da8aee1ab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2083,9 +2083,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
   // If the constant is the same after truncation to the smaller type and
   // extension to the original type, we can narrow the select.
   Type *SelType = Sel.getType();
-  Constant *TruncC = ConstantExpr::getTrunc(C, SmallType);
-  Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType);
-  if (ExtC == C && ExtInst->hasOneUse()) {
+  Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
+  if (TruncC && ExtInst->hasOneUse()) {
     Value *TruncCVal = cast<Value>(TruncC);
     if (ExtInst == Sel.getFalseValue())
       std::swap(X, TruncCVal);
@@ -2103,7 +2102,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
       // select X, (sext X), C --> select X, -1, C
       // select X, (zext X), C --> select X,  1, C
       Constant *One = ConstantInt::getTrue(SmallType);
-      Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType);
+      Value *AllOnesOrOne =
+          Builder.CreateCast((Instruction::CastOps)ExtOpcode, One, SelType);
       return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel);
     } else {
       // select X, C, (sext X) --> select X, C, 0

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 93cb5dec930b191..f6aabe17e06eb46 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -345,10 +345,12 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
 
   // Fold the constants together in the destination type:
   // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
+  const DataLayout &DL = IC.getDataLayout();
   Type *DestTy = C1->getType();
-  Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
-  Constant *FoldedC =
-      ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout());
+  Constant *CastC2 = ConstantFoldCastOperand(CastOpcode, C2, DestTy, DL);
+  if (!CastC2)
+    return false;
+  Constant *FoldedC = ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, DL);
   if (!FoldedC)
     return false;
 
@@ -1898,8 +1900,8 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
     Constant *WideC;
     if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
       return nullptr;
-    Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType());
-    if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC)
+    Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc);
+    if (!NarrowC)
       return nullptr;
     Y = NarrowC;
   }


        


More information about the llvm-commits mailing list