[llvm] [VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations (PR #148350)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 01:06:45 PDT 2025


================
@@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
   return true;
 }
 
-bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
-  // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
-  Value *LHSSrc, *RHSSrc;
-  if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
-                                m_BitCast(m_Value(RHSSrc)))))
+bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
+  // Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
+  // Supports: bitcast, trunc, sext, zext
+
+  // Check if this is a bitwise logic operation
+  auto *BinOp = dyn_cast<BinaryOperator>(&I);
+  if (!BinOp || !BinOp->isBitwiseLogicOp())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
+
+  // Get the cast instructions
+  auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
+  auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
+  if (!LHSCast || !RHSCast) {
+    LLVM_DEBUG(dbgs() << "  One or both operands are not cast instructions\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "  LHS cast: " << *LHSCast << "\n");
+  LLVM_DEBUG(dbgs() << "  RHS cast: " << *RHSCast << "\n");
+
+  // Both casts must be the same type
+  Instruction::CastOps CastOpcode = LHSCast->getOpcode();
+  if (CastOpcode != RHSCast->getOpcode())
     return false;
 
+  // Only handle supported cast operations
+  switch (CastOpcode) {
+  case Instruction::BitCast:
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    break;
+  default:
+    return false;
+  }
+
+  Value *LHSSrc = LHSCast->getOperand(0);
+  Value *RHSSrc = RHSCast->getOperand(0);
+
   // Source types must match
   if (LHSSrc->getType() != RHSSrc->getType())
     return false;
-  if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
-    return false;
 
-  // Only handle vector types
+  // Only handle vector types with integer elements
   auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
   auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
   if (!SrcVecTy || !DstVecTy)
     return false;
 
-  // Same total bit width
-  assert(SrcVecTy->getPrimitiveSizeInBits() ==
-             DstVecTy->getPrimitiveSizeInBits() &&
-         "Bitcast should preserve total bit width");
+  if (!SrcVecTy->getScalarType()->isIntegerTy() ||
+      !DstVecTy->getScalarType()->isIntegerTy())
+    return false;
+
+  // Validate cast operation constraints
+  switch (CastOpcode) {
+  case Instruction::BitCast:
+    // Total bit width must be preserved
+    if (SrcVecTy->getPrimitiveSizeInBits() !=
+        DstVecTy->getPrimitiveSizeInBits())
+      return false;
+    break;
+  case Instruction::Trunc:
+    // Source elements must be wider
+    if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits())
+      return false;
+    break;
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    // Source elements must be narrower
+    if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
+      return false;
+    break;
+  }
 
   // Cost Check :
-  // OldCost = bitlogic + 2*bitcasts
-  // NewCost = bitlogic + bitcast
-  auto *BinOp = cast<BinaryOperator>(&I);
+  // OldCost = bitlogic + 2*casts
+  // NewCost = bitlogic + cast
   InstructionCost OldCost =
       TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
-                           TTI::CastContextHint::None) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
-                           TTI::CastContextHint::None);
+      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+                           TTI::CastContextHint::None) *
+          2;
+
   InstructionCost NewCost =
       TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
+      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
                            TTI::CastContextHint::None);
----------------
RKSimon wrote:

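Missing CostKind: the `TTI.getArithmeticInstrCost` and `TTI.getCastInstrCost` calls in the OldCost/NewCost computation above should pass `CostKind` explicitly instead of relying on the default cost kind.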

https://github.com/llvm/llvm-project/pull/148350


More information about the llvm-commits mailing list