[llvm] 75b0c89 - [InstCombine][VectorCombine][NFC] Unify uses of lossless inverse cast (#156597)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 8 06:30:11 PDT 2025
Author: Hongyu Chen
Date: 2025-09-08T13:30:06Z
New Revision: 75b0c89e626f21e9ba9c920c878cc9e81471f4cf
URL: https://github.com/llvm/llvm-project/commit/75b0c89e626f21e9ba9c920c878cc9e81471f4cf
DIFF: https://github.com/llvm/llvm-project/commit/75b0c89e626f21e9ba9c920c878cc9e81471f4cf.diff
LOG: [InstCombine][VectorCombine][NFC] Unify uses of lossless inverse cast (#156597)
This patch addresses
https://github.com/llvm/llvm-project/pull/155216#discussion_r2297724663.
This patch adds a helper function that computes the lossless inverse cast of a
constant, optionally reporting which cast flags can be preserved.
Follow-up patches will add trunc/ext handling in VectorCombine and flag
preservation in InstCombine.
Added:
Modified:
llvm/include/llvm/Analysis/ConstantFolding.h
llvm/lib/Analysis/ConstantFolding.cpp
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ConstantFolding.h b/llvm/include/llvm/Analysis/ConstantFolding.h
index dcbac8a301025..5f91f9747bb97 100644
--- a/llvm/include/llvm/Analysis/ConstantFolding.h
+++ b/llvm/include/llvm/Analysis/ConstantFolding.h
@@ -226,6 +226,27 @@ LLVM_ABI bool isMathLibCallNoop(const CallBase *Call,
LLVM_ABI Constant *ReadByteArrayFromGlobal(const GlobalVariable *GV,
uint64_t Offset);
-}
+
+/// Cast flags that getLosslessInvCast reports as still valid on the cast it
+/// inverts. Only the fields relevant to the requested cast opcode are
+/// written; the remaining fields keep whatever value the caller gave them.
+struct PreservedCastFlags {
+ /// Written for ZExt only: true when the narrow inverse constant is
+ /// non-negative (its zext and sext agree), so the zext may carry `nneg`.
+ bool NNeg = false;
+ /// Written for Trunc only: truncating the zero-extended constant back to
+ /// the original type only discards zero bits, so `nuw` holds.
+ bool NUW = false;
+ /// Written for Trunc only: true when the original constant is non-negative
+ /// (its zext and sext agree), so the trunc may carry `nsw`.
+ bool NSW = false;
+};
+
+/// Try to find a constant InvC of type InvCastTo that inverts CastOp
+/// losslessly: CastOp(InvC) equals C, or is a refinement of C where C is
+/// undefined. Handles BitCast, Trunc, ZExt and SExt; any other opcode, or a
+/// C with no lossless inverse, yields nullptr. If Flags is non-null, it is
+/// updated with the cast flags that may be kept on CastOp.
+LLVM_ABI Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
+ unsigned CastOp, const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+
+/// Shorthand for getLosslessInvCast(C, DestTy, Instruction::ZExt, DL, Flags):
+/// returns C truncated to DestTy when zero-extending it back reproduces C,
+/// otherwise nullptr.
+LLVM_ABI Constant *
+getLosslessUnsignedTrunc(Constant *C, Type *DestTy, const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+
+/// Shorthand for getLosslessInvCast(C, DestTy, Instruction::SExt, DL, Flags):
+/// returns C truncated to DestTy when sign-extending it back reproduces C,
+/// otherwise nullptr.
+LLVM_ABI Constant *getLosslessSignedTrunc(Constant *C, Type *DestTy,
+ const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+} // namespace llvm
#endif
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 2148431c1acce..40e176c2ab5ce 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -4608,4 +4608,55 @@ bool llvm::isMathLibCallNoop(const CallBase *Call,
return false;
}
+/// Compute InvC of type InvCastTo such that CastOp(InvC) == C (for Trunc,
+/// possibly a refinement of an undefined C), or return nullptr if no
+/// lossless inverse is found. When Flags is non-null and an inverse is
+/// returned, the flags that remain valid on CastOp are recorded in it:
+/// NUW/NSW for Trunc, NNeg for ZExt. Flags are never written on failure.
+Constant *llvm::getLosslessInvCast(Constant *C, Type *InvCastTo,
+                                   unsigned CastOp, const DataLayout &DL,
+                                   PreservedCastFlags *Flags) {
+  switch (CastOp) {
+  case Instruction::BitCast:
+    // Bitcast is always lossless and has no flags to preserve.
+    return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
+  case Instruction::Trunc: {
+    // trunc(zext(C)) == C, so zero extension inverts the truncation.
+    auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
+    // The fold may produce nullptr (the ZExt/SExt case below guards against
+    // this as well). Do not report flags for an inverse we cannot return.
+    if (!ZExtC)
+      return nullptr;
+    if (Flags) {
+      // Truncation back on ZExt value is always NUW.
+      Flags->NUW = true;
+      // Test positivity of C: zext(C) == sext(C) iff C's sign bit is clear,
+      // which is exactly when truncating back is NSW.
+      auto *SExtC =
+          ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
+      Flags->NSW = ZExtC == SExtC;
+    }
+    return ZExtC;
+  }
+  case Instruction::SExt:
+  case Instruction::ZExt: {
+    auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
+    auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
+    // Must satisfy CastOp(InvC) == C; otherwise the truncation lost bits.
+    if (!CastInvC || CastInvC != C)
+      return nullptr;
+    if (Flags && CastOp == Instruction::ZExt) {
+      auto *SExtInvC =
+          ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
+      // Test positivity of InvC: zext and sext agree iff its sign bit is
+      // clear, which is exactly when the zext may be marked nneg.
+      Flags->NNeg = CastInvC == SExtInvC;
+    }
+    return InvC;
+  }
+  default:
+    return nullptr;
+  }
+}
+
+/// Unsigned narrowing of C to DestTy: returns InvC such that
+/// zext(InvC) == C, or nullptr when the truncation would lose bits.
+Constant *llvm::getLosslessUnsignedTrunc(Constant *C, Type *DestTy,
+                                         const DataLayout &DL,
+                                         PreservedCastFlags *Flags) {
+  // An unsigned truncation is undone by a zero extension.
+  constexpr unsigned InverseOp = Instruction::ZExt;
+  return getLosslessInvCast(C, DestTy, InverseOp, DL, Flags);
+}
+
+/// Signed narrowing of C to DestTy: returns InvC such that
+/// sext(InvC) == C, or nullptr when the truncation would lose bits.
+/// Flags is forwarded, but getLosslessInvCast records no flags for SExt.
+Constant *llvm::getLosslessSignedTrunc(Constant *C, Type *DestTy,
+                                       const DataLayout &DL,
+                                       PreservedCastFlags *Flags) {
+  // A signed truncation is undone by a sign extension.
+  constexpr unsigned InverseOp = Instruction::SExt;
+  return getLosslessInvCast(C, DestTy, InverseOp, DL, Flags);
+}
+
void TargetFolder::anchor() {}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index a13d3ceb61320..8b9df62d7c652 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1799,8 +1799,9 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
// type may provide more information to later folds, and the smaller logic
// instruction may be cheaper (particularly in the case of vectors).
Value *X;
+ auto &DL = IC.getDataLayout();
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
- if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
+ if (Constant *TruncC = getLosslessUnsignedTrunc(C, SrcTy, DL)) {
// LogicOpc (zext X), C --> zext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new ZExtInst(NewOp, DestTy);
@@ -1808,7 +1809,7 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
}
if (match(Cast, m_OneUse(m_SExtLike(m_Value(X))))) {
- if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
+ if (Constant *TruncC = getLosslessSignedTrunc(C, SrcTy, DL)) {
// LogicOpc (sext X), C --> sext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new SExtInst(NewOp, DestTy);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42b65dde67255..33b66aeaffe60 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1956,7 +1956,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
+ if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType(), DL)) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
}
@@ -2006,7 +2006,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
+ if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType(), DL)) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 01b0da3469c18..07da12a3ab2a4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6336,7 +6336,7 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
// If a lossless truncate is possible...
Type *SrcTy = CastOp0->getSrcTy();
- Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
+ Constant *Res = getLosslessInvCast(C, SrcTy, CastOp0->getOpcode(), DL);
if (Res) {
if (ICmp.isEquality())
return new ICmpInst(ICmp.getPredicate(), X, Res);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 2340028ce93dc..d3d23130b6fc4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -222,23 +222,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
const Instruction *CtxI) const;
- Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) {
- Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
- Constant *ExtTruncC =
- ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL);
- if (ExtTruncC && ExtTruncC == C)
- return TruncC;
- return nullptr;
- }
-
- Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
- return getLosslessTrunc(C, TruncTy, Instruction::ZExt);
- }
-
- Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
- return getLosslessTrunc(C, TruncTy, Instruction::SExt);
- }
-
std::optional<std::pair<Intrinsic::ID, SmallVector<Value *, 3>>>
convertOrOfShiftsToFunnelShift(Instruction &Or);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index d7310b1c741c0..a9aacc707cc20 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1642,10 +1642,11 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
}
Constant *C;
+ auto &DL = IC.getDataLayout();
if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
match(D, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
if (!TruncC)
return nullptr;
@@ -1656,7 +1657,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
match(N, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
if (!TruncC)
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 6477141ab095f..ed9a0be6981fa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -841,7 +841,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
NumZexts++;
} else if (auto *C = dyn_cast<Constant>(V)) {
// Make sure that constants can fit in the new type.
- Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType);
+ Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType, DL);
if (!Trunc)
return nullptr;
NewIncoming.push_back(Trunc);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ba8b4c47e8f88..9467463d39c0e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2375,7 +2375,7 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
// If the constant is the same after truncation to the smaller type and
// extension to the original type, we can narrow the select.
Type *SelType = Sel.getType();
- Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
+ Constant *TruncC = getLosslessInvCast(C, SmallType, ExtOpcode, DL);
if (TruncC && ExtInst->hasOneUse()) {
Value *TruncCVal = cast<Value>(TruncC);
if (ExtInst == Sel.getFalseValue())
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index abb802bab265c..a74f292524b4d 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2568,7 +2568,7 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
Constant *WideC;
if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
return nullptr;
- Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc);
+ Constant *NarrowC = getLosslessInvCast(WideC, X->getType(), CastOpc, DL);
if (!NarrowC)
return nullptr;
Y = NarrowC;
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 9dd1532d1b230..17cb18a22336a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -939,51 +939,6 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
return true;
}
-struct PreservedCastFlags {
- bool NNeg = false;
- bool NUW = false;
- bool NSW = false;
-};
-
-// Try to cast C to InvC losslessly, satisfying CastOp(InvC) == C.
-// Will try best to preserve the flags.
-static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
- Instruction::CastOps CastOp,
- const DataLayout &DL,
- PreservedCastFlags &Flags) {
- switch (CastOp) {
- case Instruction::BitCast:
- // Bitcast is always lossless.
- return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
- case Instruction::Trunc: {
- auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
- auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
- // Truncation back on ZExt value is always NUW.
- Flags.NUW = true;
- // Test positivity of C.
- Flags.NSW = ZExtC == SExtC;
- return ZExtC;
- }
- case Instruction::SExt:
- case Instruction::ZExt: {
- auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
- auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
- // Must satisfy CastOp(InvC) == C.
- if (!CastInvC || CastInvC != C)
- return nullptr;
- if (CastOp == Instruction::ZExt) {
- auto *SExtInvC =
- ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
- // Test positivity of InvC.
- Flags.NNeg = CastInvC == SExtInvC;
- }
- return InvC;
- }
- default:
- return nullptr;
- }
-}
-
/// Match:
// bitop(castop(x), C) ->
// bitop(castop(x), castop(InvC)) ->
@@ -1029,7 +984,7 @@ bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
// Find the constant InvC, such that castop(InvC) equals to C.
PreservedCastFlags RHSFlags;
- Constant *InvC = getLosslessInvCast(C, SrcTy, CastOpcode, *DL, RHSFlags);
+ Constant *InvC = getLosslessInvCast(C, SrcTy, CastOpcode, *DL, &RHSFlags);
if (!InvC)
return false;
More information about the llvm-commits
mailing list