[llvm] [InstCombine] Support multi-use values in cast elimination transforms (PR #165877)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 31 08:41:37 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Valeriy Savchenko (SavchenkoValeriy)

<details>
<summary>Changes</summary>

`canEvaluateTruncated` and `canEvaluateSExtd` previously rejected multi-use values to avoid duplication. This was overly conservative: if all users of a multi-use value are part of the transform, we can evaluate it in a different type without duplication.

This change tracks visited values and defers decisions on multi-use values until we verify all their users were visited. `EvaluateInDifferentType` now memoizes multi-use values to avoid creating duplicates.

This is applied to truncation and sext. Zext is left unchanged because of its dual-return nature.


---

Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165877.diff


7 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+312-77) 
- (modified) llvm/test/Transforms/InstCombine/cast-mul-select.ll (+21-29) 
- (modified) llvm/test/Transforms/InstCombine/cast.ll (+144) 
- (modified) llvm/test/Transforms/InstCombine/catchswitch-phi.ll (+4-6) 
- (modified) llvm/test/Transforms/InstCombine/icmp-mul-zext.ll (+3-4) 
- (modified) llvm/test/Transforms/InstCombine/logical-select-inseltpoison.ll (+6-8) 
- (modified) llvm/test/Transforms/InstCombine/logical-select.ll (+6-8) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a094981..6184c6d25d929 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,14 +12,21 @@
 
 #include "InstCombineInternal.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <iterator>
 #include <optional>
 
 using namespace llvm;
@@ -27,12 +34,19 @@ using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
-/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
-/// true for, actually insert the code to evaluate the expression.
-Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
-                                                 bool isSigned) {
+using EvaluatedMap = SmallDenseMap<Value *, Value *, 8>;
+
+static Value *EvaluateInDifferentTypeImpl(Value *V, Type *Ty, bool isSigned,
+                                          InstCombinerImpl &IC,
+                                          EvaluatedMap &Processed) {
+  // Since we cover transformation of instructions with multiple users, we might
+  // come to the same node via multiple paths. We should not create a
+  // replacement for every single one of them though.
+  if (const auto It = Processed.find(V); It != Processed.end())
+    return It->getSecond();
+
   if (Constant *C = dyn_cast<Constant>(V))
-    return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
+    return ConstantFoldIntegerCast(C, Ty, isSigned, IC.getDataLayout());
 
   // Otherwise, it must be an instruction.
   Instruction *I = cast<Instruction>(V);
@@ -50,8 +64,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
   case Instruction::Shl:
   case Instruction::UDiv:
   case Instruction::URem: {
-    Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
-    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+    Value *LHS = EvaluateInDifferentTypeImpl(I->getOperand(0), Ty, isSigned, IC,
+                                             Processed);
+    Value *RHS = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, IC,
+                                             Processed);
     Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
     if (Opc == Instruction::LShr || Opc == Instruction::AShr)
       Res->setIsExact(I->isExact());
@@ -72,8 +88,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
                                       Opc == Instruction::SExt);
     break;
   case Instruction::Select: {
-    Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
-    Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
+    Value *True = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned,
+                                              IC, Processed);
+    Value *False = EvaluateInDifferentTypeImpl(I->getOperand(2), Ty, isSigned,
+                                               IC, Processed);
     Res = SelectInst::Create(I->getOperand(0), True, False);
     break;
   }
@@ -81,8 +99,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
     PHINode *OPN = cast<PHINode>(I);
     PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
     for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
-      Value *V =
-          EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
+      Value *V = EvaluateInDifferentTypeImpl(OPN->getIncomingValue(i), Ty,
+                                             isSigned, IC, Processed);
       NPN->addIncoming(V, OPN->getIncomingBlock(i));
     }
     Res = NPN;
@@ -90,8 +108,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
   }
   case Instruction::FPToUI:
   case Instruction::FPToSI:
-    Res = CastInst::Create(
-      static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
+    Res = CastInst::Create(static_cast<Instruction::CastOps>(Opc),
+                           I->getOperand(0), Ty);
     break;
   case Instruction::Call:
     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
@@ -111,8 +129,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
     auto *ScalarTy = cast<VectorType>(Ty)->getElementType();
     auto *VTy = cast<VectorType>(I->getOperand(0)->getType());
     auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount());
-    Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned);
-    Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned);
+    Value *Op0 = EvaluateInDifferentTypeImpl(I->getOperand(0), FixedTy,
+                                             isSigned, IC, Processed);
+    Value *Op1 = EvaluateInDifferentTypeImpl(I->getOperand(1), FixedTy,
+                                             isSigned, IC, Processed);
     Res = new ShuffleVectorInst(Op0, Op1,
                                 cast<ShuffleVectorInst>(I)->getShuffleMask());
     break;
@@ -123,7 +143,22 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
   }
 
   Res->takeName(I);
-  return InsertNewInstWith(Res, I->getIterator());
+  Value *Result = IC.InsertNewInstWith(Res, I->getIterator());
+  // There is no need to keep track of the old value/new value relationship
+  // when we have only one user: we came here from that user and no one
+  // else cares.
+  if (!V->hasOneUse()) {
+    Processed[V] = Result;
+  }
+  return Result;
+}
+
+/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
+/// true for, actually insert the code to evaluate the expression.
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+                                                 bool isSigned) {
+  EvaluatedMap Processed;
+  return EvaluateInDifferentTypeImpl(V, Ty, isSigned, *this, Processed);
 }
 
 Instruction::CastOps
@@ -227,9 +262,175 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
   return nullptr;
 }
 
+namespace {
+
+/// Helper class for evaluating whether a value can be computed in a different
+/// type without changing its value. Used by cast simplification transforms.
+class TypeEvaluationHelper {
+public:
+  /// Return true if we can evaluate the specified expression tree as type Ty
+  /// instead of its larger type, and arrive with the same value.
+  /// This is used by code that tries to eliminate truncates.
+  [[nodiscard]] static bool canEvaluateTruncated(Value *V, Type *Ty,
+                                                 InstCombinerImpl &IC,
+                                                 Instruction *CxtI);
+
+  /// Determine if the specified value can be computed in the specified wider
+  /// type and produce the same low bits. If not, return false.
+  [[nodiscard]] static bool canEvaluateZExtd(Value *V, Type *Ty,
+                                             unsigned &BitsToClear,
+                                             InstCombinerImpl &IC,
+                                             Instruction *CxtI);
+
+  /// Return true if we can take the specified value and return it as type Ty
+  /// without inserting any new casts and without changing the value of the
+  /// common low bits.
+  [[nodiscard]] static bool canEvaluateSExtd(Value *V, Type *Ty);
+
+private:
+  /// Constants and extensions/truncates from the destination type are always
+  /// free to be evaluated in that type.
+  [[nodiscard]] static bool canAlwaysEvaluateInType(Value *V, Type *Ty);
+
+  /// Check if we traversed all the users of the multi-use values we've seen.
+  [[nodiscard]] bool allPendingVisited() const {
+    return llvm::all_of(Pending,
+                        [this](Value *V) { return Visited.contains(V); });
+  }
+
+  /// A generic wrapper for canEvaluate* recursions to inject visitation
+  /// tracking and enforce correct multi-use value evaluations.
+  [[nodiscard]] bool
+  canEvaluate(Value *V, Type *Ty,
+              llvm::function_ref<bool(Value *, Type *Type)> Pred) {
+    if (canAlwaysEvaluateInType(V, Ty))
+      return true;
+
+    if (!isa<Instruction>(V))
+      return false;
+
+    auto *I = cast<Instruction>(V);
+    // We insert false by default to return false when we encounter user loops.
+    const auto [It, Inserted] = Visited.insert({V, false});
+
+    // There are three possible cases for us having information on this value
+    // in the Visited map:
+    //   1. We properly checked it and concluded that we can evaluate it (true)
+    //   2. We properly checked it and concluded that we can't (false)
+    //   3. We started to check it, but during the recursive traversal we came
+    //      back to it.
+    //
+    // For cases 1 and 2, we can safely return the stored result. For case 3, we
+    // can potentially have a situation where we can evaluate recursive user
+    // chains, but that can be quite tricky to do properly and instead, we
+    // return false.
+    //
+    // In any case, we should return whatever was there in the map to begin
+    // with.
+    if (!Inserted)
+      return It->getSecond();
+
+    // We can easily make a decision about single-user values whether they can
+    // be evaluated in a different type or not, we came from that user. This is
+    // not as simple for multi-user values.
+    //
+    // In general, we have the following case (inverted control-flow, users are
+    // at the top):
+    //
+    // Cast %A
+    //  ____|
+    // /
+    // %A = Use %B, %C
+    //  ________|   |
+    // /            |
+    // %B = Use %D  |
+    //  ________|   |
+    // /            |
+    // %D = Use %C  |
+    //  ________|___|
+    // /
+    // %C = ...
+    //
+    // In this case, when we check %A, %B and %D, we are confident that we can
+    // make the decision here and now, since we came from their only users.
+    //
+    // For %C, it is harder. We come there twice, and when we come the first
+    // time, it's hard to tell if we will visit the second user (technically
+    // it's not hard, but we might need a lot of repetitive checks with non-zero
+    // cost).
+    //
+    // In the case above, we are allowed to evaluate %C in different type
+    // because all of its users were part of the traversal.
+    //
+    // In the following case, however, we can't make this conclusion:
+    //
+    // Cast %A
+    //  ____|
+    // /
+    // %A = Use %B, %C
+    //  ________|   |
+    // /            |
+    // %B = Use %D  |
+    //  ________|   |
+    // /            |
+    // %D = Use %C  |
+    //          |   |
+    // foo(%C)  |   |    <- never traversing foo(%C)
+    //  ________|___|
+    // /
+    // %C = ...
+    //
+    // In this case, we still can evaluate %C in a different type, but we'd need
+    // to create a copy of the original %C to be used in foo(%C). Such
+    // duplication might not be profitable.
+    //
+    // For this reason, we collect all users of the multi-user values and mark
+    // them as "pending" and defer this decision to the very end. When we are
+    // done and ready to have a positive verdict, we should double-check all
+    // of the pending users and ensure that we visited them. allPendingVisited
+    // predicate checks exactly that.
+    if (!I->hasOneUse()) {
+      llvm::transform(I->uses(), std::back_inserter(Pending),
+                      [](Use &U) { return U.getUser(); });
+    }
+
+    const bool Result = Pred(V, Ty);
+    // We have to set result this way and not via It because Pred is recursive
+    // and it is very likely that we grew Visited and invalidated It.
+    Visited[V] = Result;
+    return Result;
+  }
+
+  /// Filter out values that we can not evaluate in the destination type for
+  /// free.
+  [[nodiscard]] bool canNotEvaluateInType(Value *V, Type *Ty);
+
+  [[nodiscard]] bool canEvaluateTruncatedImpl(Value *V, Type *Ty,
+                                              InstCombinerImpl &IC,
+                                              Instruction *CxtI);
+  [[nodiscard]] bool canEvaluateTruncatedPred(Value *V, Type *Ty,
+                                              InstCombinerImpl &IC,
+                                              Instruction *CxtI);
+  [[nodiscard]] bool canEvaluateZExtdImpl(Value *V, Type *Ty,
+                                          unsigned &BitsToClear,
+                                          InstCombinerImpl &IC,
+                                          Instruction *CxtI);
+  [[nodiscard]] bool canEvaluateSExtdImpl(Value *V, Type *Ty);
+  [[nodiscard]] bool canEvaluateSExtdPred(Value *V, Type *Ty);
+
+  /// A bookkeeping map to memorize an already made decision for a traversed
+  /// value.
+  SmallDenseMap<Value *, bool, 8> Visited;
+
+  /// A list of pending values to check in the end.
+  SmallVector<Value *, 8> Pending;
+};
+
+} // anonymous namespace
+
 /// Constants and extensions/truncates from the destination type are always
 /// free to be evaluated in that type. This is a helper for canEvaluate*.
-static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canAlwaysEvaluateInType(Value *V, Type *Ty) {
   if (isa<Constant>(V))
     return match(V, m_ImmConstant());
 
@@ -243,7 +444,7 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
 
 /// Filter out values that we can not evaluate in the destination type for free.
 /// This is a helper for canEvaluate*.
-static bool canNotEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canNotEvaluateInType(Value *V, Type *Ty) {
   if (!isa<Instruction>(V))
     return true;
   // We don't extend or shrink something that has multiple uses --  doing so
@@ -265,13 +466,27 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) {
 ///
 /// This function works on both vectors and scalars.
 ///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
-                                 Instruction *CxtI) {
-  if (canAlwaysEvaluateInType(V, Ty))
-    return true;
-  if (canNotEvaluateInType(V, Ty))
-    return false;
+bool TypeEvaluationHelper::canEvaluateTruncated(Value *V, Type *Ty,
+                                                InstCombinerImpl &IC,
+                                                Instruction *CxtI) {
+  TypeEvaluationHelper TYH;
+  return TYH.canEvaluateTruncatedImpl(V, Ty, IC, CxtI) &&
+         // We need to check whether we visited all users of multi-user values,
+         // and we have to do it at the very end, outside of the recursion.
+         TYH.allPendingVisited();
+}
 
+bool TypeEvaluationHelper::canEvaluateTruncatedImpl(Value *V, Type *Ty,
+                                                    InstCombinerImpl &IC,
+                                                    Instruction *CxtI) {
+  return canEvaluate(V, Ty, [this, &IC, CxtI](Value *V, Type *Ty) {
+    return canEvaluateTruncatedPred(V, Ty, IC, CxtI);
+  });
+}
+
+bool TypeEvaluationHelper::canEvaluateTruncatedPred(Value *V, Type *Ty,
+                                                    InstCombinerImpl &IC,
+                                                    Instruction *CxtI) {
   auto *I = cast<Instruction>(V);
   Type *OrigTy = V->getType();
   switch (I->getOpcode()) {
@@ -282,8 +497,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
   case Instruction::Or:
   case Instruction::Xor:
     // These operators can all arbitrarily be extended or truncated.
-    return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-           canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+    return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+           canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
 
   case Instruction::UDiv:
   case Instruction::URem: {
@@ -296,8 +511,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     // based on later context may introduce a trap.
     if (IC.MaskedValueIsZero(I->getOperand(0), Mask, I) &&
         IC.MaskedValueIsZero(I->getOperand(1), Mask, I)) {
-      return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) &&
-             canEvaluateTruncated(I->getOperand(1), Ty, IC, I);
+      return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
     }
     break;
   }
@@ -308,8 +523,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     KnownBits AmtKnownBits =
         llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
     if (AmtKnownBits.getMaxValue().ult(BitWidth))
-      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+      return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::LShr: {
@@ -329,12 +544,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
       if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
         auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
         if ((MaxShiftAmt + DemandedBits).ule(BitWidth))
-          return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-                 canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+          return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+                 canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
       }
       if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, CxtI))
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-               canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+        return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+               canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
     }
     break;
   }
@@ -351,8 +566,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     unsigned ShiftedBits = OrigBitWidth - BitWidth;
     if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
         ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), CxtI))
-      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+      return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::Trunc:
@@ -365,18 +580,18 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     return true;
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
-    return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
-           canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
+    return canEvaluateTruncatedImpl(SI->getTrueValue(), Ty, IC, CxtI) &&
+           canEvaluateTruncatedImpl(SI->getFalseValue(), Ty, IC, CxtI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands.  Note that we never
-    // get into trouble with cyclic PHIs here because we only consider
-    // instructions with a single use.
+    // get into trouble with cyclic PHIs here because canEvaluate handles use
+    // chain loops.
     PHINode *PN = cast<PHINode>(I);
-    for (Value *IncValue : PN->incoming_values())
-      if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
-        return false;
-    return true;
+    return llvm::all_of(
+        PN->incoming_values(), [this, Ty, &IC, CxtI](Value *IncValue) {
+          return canEvaluateTruncatedImpl(IncValue, Ty, IC, CxtI);
+        });
   }
   case Instruction::FPToUI:
   case Instruction::FPToSI: {
@@ -385,14 +600,14 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     // that did not exist in the original code.
     Type *InputTy = I->g...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/165877


More information about the llvm-commits mailing list