[llvm] [VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations (PR #148350)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 01:06:46 PDT 2025
================
@@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
return true;
}
-bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
- // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
- Value *LHSSrc, *RHSSrc;
- if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
- m_BitCast(m_Value(RHSSrc)))))
+bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
+ // Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
+ // Supports: bitcast, trunc, sext, zext
+
+ // Check if this is a bitwise logic operation
+ auto *BinOp = dyn_cast<BinaryOperator>(&I);
+ if (!BinOp || !BinOp->isBitwiseLogicOp())
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
+
+ // Get the cast instructions
+ auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
+ auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
+ if (!LHSCast || !RHSCast) {
+ LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n");
+ LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n");
+
+ // Both casts must be the same type
+ Instruction::CastOps CastOpcode = LHSCast->getOpcode();
+ if (CastOpcode != RHSCast->getOpcode())
return false;
+ // Only handle supported cast operations
+ switch (CastOpcode) {
+ case Instruction::BitCast:
+ case Instruction::Trunc:
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ break;
+ default:
+ return false;
+ }
+
+ Value *LHSSrc = LHSCast->getOperand(0);
+ Value *RHSSrc = RHSCast->getOperand(0);
+
// Source types must match
if (LHSSrc->getType() != RHSSrc->getType())
return false;
- if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
- return false;
- // Only handle vector types
+ // Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;
- // Same total bit width
- assert(SrcVecTy->getPrimitiveSizeInBits() ==
- DstVecTy->getPrimitiveSizeInBits() &&
- "Bitcast should preserve total bit width");
+ if (!SrcVecTy->getScalarType()->isIntegerTy() ||
+ !DstVecTy->getScalarType()->isIntegerTy())
+ return false;
+
+ // Validate cast operation constraints
+ switch (CastOpcode) {
+ case Instruction::BitCast:
+ // Total bit width must be preserved
+ if (SrcVecTy->getPrimitiveSizeInBits() !=
+ DstVecTy->getPrimitiveSizeInBits())
+ return false;
+ break;
+ case Instruction::Trunc:
+ // Source elements must be wider
+ if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits())
+ return false;
+ break;
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ // Source elements must be narrower
+ if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
+ return false;
+ break;
+ }
// Cost Check :
- // OldCost = bitlogic + 2*bitcasts
- // NewCost = bitlogic + bitcast
- auto *BinOp = cast<BinaryOperator>(&I);
+ // OldCost = bitlogic + 2*casts
+ // NewCost = bitlogic + cast
InstructionCost OldCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
- TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
- TTI::CastContextHint::None) +
- TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
- TTI::CastContextHint::None);
+ TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+ TTI::CastContextHint::None) *
+ 2;
----------------
RKSimon wrote:
Can we hoist the separate getCastInstrCost calls here, to avoid calling them again for the !hasOneUse cases below? Add the Instruction* args as well to help improve the costs — we can't do that for the new cost calculation, but it's still useful for the old costs. We're missing the CostKind as well.
https://github.com/llvm/llvm-project/pull/148350
More information about the llvm-commits
mailing list