[llvm] [InstCombine] Support multi-use values in cast elimination transforms (PR #165877)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 31 08:41:37 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Valeriy Savchenko (SavchenkoValeriy)
<details>
<summary>Changes</summary>
`canEvaluateTruncated` and `canEvaluateSExtd` previously rejected multi-use values to avoid duplication. This was overly conservative: if all users of a multi-use value are part of the transform, we can evaluate it in a different type without duplication.
This change tracks visited values and defers decisions on multi-use values until we verify all their users were visited. `EvaluateInDifferentType` now memoizes multi-use values to avoid creating duplicates.
Applied to truncation and sext. Zext unchanged due to its dual-return nature.
---
Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165877.diff
7 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+312-77)
- (modified) llvm/test/Transforms/InstCombine/cast-mul-select.ll (+21-29)
- (modified) llvm/test/Transforms/InstCombine/cast.ll (+144)
- (modified) llvm/test/Transforms/InstCombine/catchswitch-phi.ll (+4-6)
- (modified) llvm/test/Transforms/InstCombine/icmp-mul-zext.ll (+3-4)
- (modified) llvm/test/Transforms/InstCombine/logical-select-inseltpoison.ll (+6-8)
- (modified) llvm/test/Transforms/InstCombine/logical-select.ll (+6-8)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a094981..6184c6d25d929 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,14 +12,21 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <iterator>
#include <optional>
using namespace llvm;
@@ -27,12 +34,19 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
-/// true for, actually insert the code to evaluate the expression.
-Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
- bool isSigned) {
+using EvaluatedMap = SmallDenseMap<Value *, Value *, 8>;
+
+static Value *EvaluateInDifferentTypeImpl(Value *V, Type *Ty, bool isSigned,
+ InstCombinerImpl &IC,
+ EvaluatedMap &Processed) {
+ // Since we cover transformation of instructions with multiple users, we might
+ // come to the same node via multiple paths. We should not create a
+ // replacement for every single one of them though.
+ if (const auto It = Processed.find(V); It != Processed.end())
+ return It->getSecond();
+
if (Constant *C = dyn_cast<Constant>(V))
- return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
+ return ConstantFoldIntegerCast(C, Ty, isSigned, IC.getDataLayout());
// Otherwise, it must be an instruction.
Instruction *I = cast<Instruction>(V);
@@ -50,8 +64,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
case Instruction::Shl:
case Instruction::UDiv:
case Instruction::URem: {
- Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
- Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+ Value *LHS = EvaluateInDifferentTypeImpl(I->getOperand(0), Ty, isSigned, IC,
+ Processed);
+ Value *RHS = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, IC,
+ Processed);
Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
if (Opc == Instruction::LShr || Opc == Instruction::AShr)
Res->setIsExact(I->isExact());
@@ -72,8 +88,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Opc == Instruction::SExt);
break;
case Instruction::Select: {
- Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
- Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
+ Value *True = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned,
+ IC, Processed);
+ Value *False = EvaluateInDifferentTypeImpl(I->getOperand(2), Ty, isSigned,
+ IC, Processed);
Res = SelectInst::Create(I->getOperand(0), True, False);
break;
}
@@ -81,8 +99,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
PHINode *OPN = cast<PHINode>(I);
PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
- Value *V =
- EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
+ Value *V = EvaluateInDifferentTypeImpl(OPN->getIncomingValue(i), Ty,
+ isSigned, IC, Processed);
NPN->addIncoming(V, OPN->getIncomingBlock(i));
}
Res = NPN;
@@ -90,8 +108,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
case Instruction::FPToUI:
case Instruction::FPToSI:
- Res = CastInst::Create(
- static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
+ Res = CastInst::Create(static_cast<Instruction::CastOps>(Opc),
+ I->getOperand(0), Ty);
break;
case Instruction::Call:
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
@@ -111,8 +129,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
auto *ScalarTy = cast<VectorType>(Ty)->getElementType();
auto *VTy = cast<VectorType>(I->getOperand(0)->getType());
auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount());
- Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned);
- Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned);
+ Value *Op0 = EvaluateInDifferentTypeImpl(I->getOperand(0), FixedTy,
+ isSigned, IC, Processed);
+ Value *Op1 = EvaluateInDifferentTypeImpl(I->getOperand(1), FixedTy,
+ isSigned, IC, Processed);
Res = new ShuffleVectorInst(Op0, Op1,
cast<ShuffleVectorInst>(I)->getShuffleMask());
break;
@@ -123,7 +143,22 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
Res->takeName(I);
- return InsertNewInstWith(Res, I->getIterator());
+ Value *Result = IC.InsertNewInstWith(Res, I->getIterator());
+ // There is no need in keeping track of the old value/new value relationship
+ // when we have only one user, we came here from that user and no-one
+ // else cares.
+ if (!V->hasOneUse()) {
+ Processed[V] = Result;
+ }
+ return Result;
+}
+
+/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
+/// true for, actually insert the code to evaluate the expression.
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+ bool isSigned) {
+ EvaluatedMap Processed;
+ return EvaluateInDifferentTypeImpl(V, Ty, isSigned, *this, Processed);
}
Instruction::CastOps
@@ -227,9 +262,175 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
return nullptr;
}
+namespace {
+
+/// Helper class for evaluating whether a value can be computed in a different
+/// type without changing its value. Used by cast simplification transforms.
+class TypeEvaluationHelper {
+public:
+ /// Return true if we can evaluate the specified expression tree as type Ty
+ /// instead of its larger type, and arrive with the same value.
+ /// This is used by code that tries to eliminate truncates.
+ [[nodiscard]] static bool canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Determine if the specified value can be computed in the specified wider
+ /// type and produce the same low bits. If not, return false.
+ [[nodiscard]] static bool canEvaluateZExtd(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Return true if we can take the specified value and return it as type Ty
+ /// without inserting any new casts and without changing the value of the
+ /// common low bits.
+ [[nodiscard]] static bool canEvaluateSExtd(Value *V, Type *Ty);
+
+private:
+ /// Constants and extensions/truncates from the destination type are always
+ /// free to be evaluated in that type.
+ [[nodiscard]] static bool canAlwaysEvaluateInType(Value *V, Type *Ty);
+
+ /// Check if we traversed all the users of the multi-use values we've seen.
+ [[nodiscard]] bool allPendingVisited() const {
+ return llvm::all_of(Pending,
+ [this](Value *V) { return Visited.contains(V); });
+ }
+
+ /// A generic wrapper for canEvaluate* recursions to inject visitation
+ /// tracking and enforce correct multi-use value evaluations.
+ [[nodiscard]] bool
+ canEvaluate(Value *V, Type *Ty,
+ llvm::function_ref<bool(Value *, Type *Type)> Pred) {
+ if (canAlwaysEvaluateInType(V, Ty))
+ return true;
+
+ if (!isa<Instruction>(V))
+ return false;
+
+ auto *I = cast<Instruction>(V);
+ // We insert false by default to return false when we encounter user loops.
+ const auto [It, Inserted] = Visited.insert({V, false});
+
+ // There are three possible cases for us having information on this value
+ // in the Visited map:
+ // 1. We properly checked it and concluded that we can evaluate it (true)
+ // 2. We properly checked it and concluded that we can't (false)
+ // 3. We started to check it, but during the recursive traversal we came
+ // back to it.
+ //
+ // For cases 1 and 2, we can safely return the stored result. For case 3, we
+ // can potentially have a situation where we can evaluate recursive user
+ // chains, but that can be quite tricky to do properly and instead, we
+ // return false.
+ //
+ // In any case, we should return whatever was there in the map to begin
+ // with.
+ if (!Inserted)
+ return It->getSecond();
+
+ // We can easily make a decision about single-user values whether they can
+ // be evaluated in a different type or not, we came from that user. This is
+ // not as simple for multi-user values.
+ //
+ // In general, we have the following case (inverted control-flow, users are
+ // at the top):
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, when we check %A, %B and %C, we are confident that we can
+ // make the decision here and now, since we came from their only users.
+ //
+ // For %C, it is harder. We come there twice, and when we come the first
+ // time, it's hard to tell if we will visit the second user (technically
+ // it's not hard, but we might need a lot of repetitive checks with non-zero
+ // cost).
+ //
+ // In the case above, we are allowed to evaluate %C in different type
+ // because all of its users were part of the traversal.
+ //
+ // In the following case, however, we can't make this conclusion:
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // | |
+ // foo(%C) | | <- never traversing foo(%C)
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, we still can evaluate %C in a different type, but we'd need
+ // to create a copy of the original %C to be used in foo(%C). Such
+ // duplication might be not profitable.
+ //
+ // For this reason, we collect all users of the multi-user values and mark
+ // them as "pending" and defer this decision to the very end. When we are
+ // done and ready to have a positive verdict, we should double-check all
+ // of the pending users and ensure that we visited them. allPendingVisited
+ // predicate checks exactly that.
+ if (!I->hasOneUse()) {
+ llvm::transform(I->uses(), std::back_inserter(Pending),
+ [](Use &U) { return U.getUser(); });
+ }
+
+ const bool Result = Pred(V, Ty);
+ // We have to set result this way and not via It because Pred is recursive
+ // and it is very likely that we grew Visited and invalidated It.
+ Visited[V] = Result;
+ return Result;
+ }
+
+ /// Filter out values that we can not evaluate in the destination type for
+ /// free.
+ [[nodiscard]] bool canNotEvaluateInType(Value *V, Type *Ty);
+
+ [[nodiscard]] bool canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateZExtdImpl(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateSExtdImpl(Value *V, Type *Ty);
+ [[nodiscard]] bool canEvaluateSExtdPred(Value *V, Type *Ty);
+
+ /// A bookkeeping map to memorize an already made decision for a traversed
+ /// value.
+ SmallDenseMap<Value *, bool, 8> Visited;
+
+ /// A list of pending values to check in the end.
+ SmallVector<Value *, 8> Pending;
+};
+
+} // anonymous namespace
+
/// Constants and extensions/truncates from the destination type are always
/// free to be evaluated in that type. This is a helper for canEvaluate*.
-static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canAlwaysEvaluateInType(Value *V, Type *Ty) {
if (isa<Constant>(V))
return match(V, m_ImmConstant());
@@ -243,7 +444,7 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
/// Filter out values that we can not evaluate in the destination type for free.
/// This is a helper for canEvaluate*.
-static bool canNotEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canNotEvaluateInType(Value *V, Type *Ty) {
if (!isa<Instruction>(V))
return true;
// We don't extend or shrink something that has multiple uses -- doing so
@@ -265,13 +466,27 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) {
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
- Instruction *CxtI) {
- if (canAlwaysEvaluateInType(V, Ty))
- return true;
- if (canNotEvaluateInType(V, Ty))
- return false;
+bool TypeEvaluationHelper::canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ TypeEvaluationHelper TYH;
+ return TYH.canEvaluateTruncatedImpl(V, Ty, IC, CxtI) &&
+ // We need to check whether we visited all users of multi-user values,
+ // and we have to do it at the very end, outside of the recursion.
+ TYH.allPendingVisited();
+}
+bool TypeEvaluationHelper::canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ return canEvaluate(V, Ty, [this, &IC, CxtI](Value *V, Type *Ty) {
+ return canEvaluateTruncatedPred(V, Ty, IC, CxtI);
+ });
+}
+
+bool TypeEvaluationHelper::canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
auto *I = cast<Instruction>(V);
Type *OrigTy = V->getType();
switch (I->getOpcode()) {
@@ -282,8 +497,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
case Instruction::Or:
case Instruction::Xor:
// These operators can all arbitrarily be extended or truncated.
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
case Instruction::UDiv:
case Instruction::URem: {
@@ -296,8 +511,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// based on later context may introduce a trap.
if (IC.MaskedValueIsZero(I->getOperand(0), Mask, I) &&
IC.MaskedValueIsZero(I->getOperand(1), Mask, I)) {
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, I);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -308,8 +523,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
KnownBits AmtKnownBits =
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
if (AmtKnownBits.getMaxValue().ult(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::LShr: {
@@ -329,12 +544,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
if ((MaxShiftAmt + DemandedBits).ule(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -351,8 +566,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
unsigned ShiftedBits = OrigBitWidth - BitWidth;
if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::Trunc:
@@ -365,18 +580,18 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
return true;
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
- return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
- canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(SI->getTrueValue(), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(SI->getFalseValue(), Ty, IC, CxtI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
- // get into trouble with cyclic PHIs here because we only consider
- // instructions with a single use.
+ // get into trouble with cyclic PHIs here because canEvaluate handles use
+ // chain loops.
PHINode *PN = cast<PHINode>(I);
- for (Value *IncValue : PN->incoming_values())
- if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
- return false;
- return true;
+ return llvm::all_of(
+ PN->incoming_values(), [this, Ty, &IC, CxtI](Value *IncValue) {
+ return canEvaluateTruncatedImpl(IncValue, Ty, IC, CxtI);
+ });
}
case Instruction::FPToUI:
case Instruction::FPToSI: {
@@ -385,14 +600,14 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// that did not exist in the original code.
Type *InputTy = I->g...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/165877
More information about the llvm-commits
mailing list