[llvm] [VectorCombine] Support pattern `bitop(cast(x), C) -> cast(bitop(x, InvC))` (PR #155216)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 00:22:08 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: XChy (XChy)

<details>
<summary>Changes</summary>

Resolves #154797.
This patch adds the fold `bitop(cast(x), C) -> bitop(cast(x), cast(InvC)) -> cast(bitop(x, InvC))`.
The helper function `getLosslessInvCast` tries to compute a constant `InvC` satisfying `castop(InvC) == C`, and tries its best to preserve the poison-generating flags of the cast operation.

---
Full diff: https://github.com/llvm/llvm-project/pull/155216.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+155) 
- (modified) llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll (+160) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 1275d53a075b5..e351e9205499b 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -118,6 +118,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitOpOfCastops(Instruction &I);
+  bool foldBitOpOfCastConstant(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
   bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
@@ -929,6 +930,158 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
   return true;
 }
 
+struct PreservedCastFlags { // Cast flags that hold for castop(InvC).
+  bool NUW = false;  // `nuw` on trunc
+  bool NSW = false;  // `nsw` on trunc
+  bool NNeg = false; // `nneg` on zext
+};
+
+/// Try to find InvC such that CastOp(InvC) == C without losing information.
+/// On success, \p Flags records which cast flags hold for CastOp(InvC).
+static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
+                                    unsigned CastOp, const DataLayout &DL,
+                                    PreservedCastFlags &Flags) {
+  switch (CastOp) {
+  case Instruction::BitCast: // Reinterpreting the bits is always lossless.
+    return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
+  case Instruction::Trunc: {
+    // trunc(zext(C)) drops only zero bits (nuw); nsw iff zext(C) == sext(C).
+    auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
+    auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
+    if (!ZExtC) // Bail out before touching Flags if C could not be folded.
+      return nullptr;
+    Flags.NUW = true;
+    Flags.NSW = ZExtC == SExtC;
+    return ZExtC;
+  }
+  case Instruction::SExt:
+  case Instruction::ZExt: {
+    auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
+    auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
+    // Must fold back to exactly C; this also rejects a failed (null) fold.
+    if (CastInvC != C)
+      return nullptr;
+    if (CastOp == Instruction::ZExt) {
+      auto *SExtInvC =
+          ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
+      // zext(InvC) == sext(InvC) iff InvC is non-negative, so nneg holds.
+      Flags.NNeg = CastInvC == SExtInvC;
+    }
+    return InvC;
+  }
+  default:
+    return nullptr;
+  }
+}
+
+/// Fold bitop(castop(x), C) -> bitop(castop(x), castop(InvC))
+///                          -> castop(bitop(x, InvC)),
+/// where InvC is a constant satisfying castop(InvC) == C. Handles bitcast,
+/// trunc, sext and zext between fixed integer vectors; the new cast keeps only
+/// the flags that hold for both castop(x) and castop(InvC).
+bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
+  Instruction *LHS;
+  Constant *C;
+
+  // Match bitop(LHS, C) in either operand order.
+  if (!match(&I, m_c_BitwiseLogic(m_Instruction(LHS), m_Constant(C))))
+    return false;
+
+  // The non-constant operand must be a cast instruction.
+  auto *LHSCast = dyn_cast<CastInst>(LHS);
+  if (!LHSCast)
+    return false;
+
+  Instruction::CastOps CastOpcode = LHSCast->getOpcode();
+
+  // Only handle casts for which getLosslessInvCast can invert a constant.
+  switch (CastOpcode) {
+  case Instruction::BitCast:
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    break;
+  default:
+    return false;
+  }
+
+  Value *LHSSrc = LHSCast->getOperand(0);
+
+  // Both the source and destination must be fixed vectors of integers.
+  auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
+  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
+  if (!SrcVecTy || !DstVecTy)
+    return false;
+
+  if (!SrcVecTy->getScalarType()->isIntegerTy() ||
+      !DstVecTy->getScalarType()->isIntegerTy())
+    return false;
+
+  // Find a constant InvC such that castop(InvC) == C.
+  PreservedCastFlags RHSFlags;
+  Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, RHSFlags);
+  if (!InvC)
+    return false;
+
+  // Cost check:
+  //   OldCost = bitop(DstVecTy) + cast
+  //   NewCost = bitop(SrcVecTy) + cast (+ old cast if it has other uses)
+
+  // Cost of the existing cast, with the instruction itself as context.
+  InstructionCost LHSCastCost =
+      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+                           TTI::CastContextHint::None, CostKind, LHSCast);
+
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(I.getOpcode(), DstVecTy, CostKind) +
+      LHSCastCost;
+
+  // The new cast does not exist yet, so cost it without an instruction.
+  InstructionCost GenericCastCost = TTI.getCastInstrCost(
+      CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);
+
+  InstructionCost NewCost =
+      TTI.getArithmeticInstrCost(I.getOpcode(), SrcVecTy, CostKind) +
+      GenericCastCost;
+
+  // A multi-use cast stays alive after the fold, so it is still paid for.
+  if (!LHSCast->hasOneUse())
+    NewCost += LHSCastCost;
+
+  LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
+                    << " NewCost=" << NewCost << "\n");
+
+  if (NewCost > OldCost)
+    return false;
+
+  // Perform the bitop on the source type with the inverted constant.
+  Value *NewOp = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
+                                     LHSSrc, InvC, I.getName() + ".inner");
+  if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
+    NewBinOp->copyIRFlags(&I);
+
+  Worklist.pushValue(NewOp);
+
+  // Use CastInst::Create so we always get a new instruction to set flags on.
+  Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
+
+  // Flags valid for castop(InvC), intersected with those of castop(x).
+  if (RHSFlags.NNeg)
+    NewCast->setNonNeg();
+  if (RHSFlags.NSW)
+    NewCast->setHasNoSignedWrap();
+  if (RHSFlags.NUW)
+    NewCast->setHasNoUnsignedWrap();
+
+  NewCast->andIRFlags(LHSCast);
+
+  // Insert the new cast instruction.
+  Value *Result = Builder.Insert(NewCast);
+
+  replaceValue(I, *Result);
+  return true;
+}
+
 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
 /// destination type followed by shuffle. This can enable further transforms by
 /// moving bitcasts or shuffles together.
@@ -4206,6 +4359,8 @@ bool VectorCombine::run() {
       case Instruction::Xor:
         if (foldBitOpOfCastops(I))
           return true;
+        if (foldBitOpOfCastConstant(I))
+          return true;
         break;
       case Instruction::PHI:
         if (shrinkPhiOfShuffles(I))
diff --git a/llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll b/llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll
index 220556c8c38c3..cd77818a2f9b6 100644
--- a/llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll
@@ -260,3 +260,163 @@ define <4 x i32> @or_zext_nneg(<4 x i16> %a, <4 x i16> %b) {
   %or = or <4 x i32> %z1, %z2
   ret <4 x i32> %or
 }
+
+; Test bitwise operations with integer-to-integer bitcast with one constant
+define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant(<4 x i16> %a) {
+; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
+; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[BC1]]
+;
+  %bc1 = bitcast <4 x i16> %a to <2 x i32>
+  %or = or <2 x i32> %bc1, <i32 1000000, i32 2000000>
+  ret <2 x i32> %or
+}
+
+define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant_commuted(<4 x i16> %a) {
+; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant_commuted(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
+; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[BC1]]
+;
+  %bc1 = bitcast <4 x i16> %a to <2 x i32>
+  %or = or <2 x i32> <i32 1000000, i32 2000000>, %bc1
+  ret <2 x i32> %or
+}
+
+; Test bitwise operations with truncate and one constant
+define <4 x i16> @or_trunc_v4i32_to_v4i16_constant(<4 x i32> %a) {
+; CHECK-LABEL: @or_trunc_v4i32_to_v4i16_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i32> [[A1:%.*]], <i32 1, i32 2, i32 3, i32 4>
+; CHECK-NEXT:    [[T1:%.*]] = trunc <4 x i32> [[A]] to <4 x i16>
+; CHECK-NEXT:    ret <4 x i16> [[T1]]
+;
+  %t1 = trunc <4 x i32> %a to <4 x i16>
+  %or = or <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
+  ret <4 x i16> %or
+}
+
+; Test bitwise operations with zero extend and one constant
+define <4 x i32> @or_zext_v4i16_to_v4i32_constant(<4 x i16> %a) {
+; CHECK-LABEL: @or_zext_v4i16_to_v4i32_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 1, i16 2, i16 3, i16 4>
+; CHECK-NEXT:    [[Z1:%.*]] = zext <4 x i16> [[A]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[Z1]]
+;
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @or_zext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
+; CHECK-LABEL: @or_zext_v4i8_to_v4i32_constant_with_loss(
+; CHECK-NEXT:    [[Z1:%.*]] = zext <4 x i8> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1024, i32 129, i32 3, i32 4>
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = zext <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 1024, i32 129, i32 3, i32 4>
+  ret <4 x i32> %or
+}
+
+; Test bitwise operations with sign extend and one constant
+define <4 x i32> @or_sext_v4i8_to_v4i32_positive_constant(<4 x i8> %a) {
+; CHECK-LABEL: @or_sext_v4i8_to_v4i32_positive_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i8> [[A1:%.*]], <i8 1, i8 2, i8 3, i8 4>
+; CHECK-NEXT:    [[S1:%.*]] = sext <4 x i8> [[A]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[S1]]
+;
+  %s1 = sext <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %s1, <i32 1, i32 2, i32 3, i32 4>
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @or_sext_v4i8_to_v4i32_minus_constant(<4 x i8> %a) {
+; CHECK-LABEL: @or_sext_v4i8_to_v4i32_minus_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i8> [[A1:%.*]], <i8 -1, i8 -2, i8 -3, i8 -4>
+; CHECK-NEXT:    [[S1:%.*]] = sext <4 x i8> [[A]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[S1]]
+;
+  %s1 = sext <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %s1, <i32 -1, i32 -2, i32 -3, i32 -4>
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @or_sext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
+; CHECK-LABEL: @or_sext_v4i8_to_v4i32_constant_with_loss(
+; CHECK-NEXT:    [[Z1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 -10000, i32 2, i32 3, i32 4>
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = sext <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 -10000, i32 2, i32 3, i32 4>
+  ret <4 x i32> %or
+}
+
+; Test truncate with flag preservation and one constant
+define <4 x i16> @and_trunc_nuw_nsw_constant(<4 x i32> %a) {
+; CHECK-LABEL: @and_trunc_nuw_nsw_constant(
+; CHECK-NEXT:    [[A:%.*]] = and <4 x i32> [[A1:%.*]], <i32 1, i32 2, i32 3, i32 4>
+; CHECK-NEXT:    [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A]] to <4 x i16>
+; CHECK-NEXT:    ret <4 x i16> [[T1]]
+;
+  %t1 = trunc nuw nsw <4 x i32> %a to <4 x i16>
+  %and = and <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
+  ret <4 x i16> %and
+}
+
+define <4 x i8> @and_trunc_nuw_nsw_minus_constant(<4 x i32> %a) {
+; CHECK-LABEL: @and_trunc_nuw_nsw_minus_constant(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], <i32 240, i32 241, i32 242, i32 243>
+; CHECK-NEXT:    [[AND:%.*]] = trunc nuw <4 x i32> [[AND_INNER]] to <4 x i8>
+; CHECK-NEXT:    ret <4 x i8> [[AND]]
+;
+  %t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
+  %and = and <4 x i8> %t1, <i8 240, i8 241, i8 242, i8 243>
+  ret <4 x i8> %and
+}
+
+define <4 x i8> @and_trunc_nuw_nsw_multiconstant(<4 x i32> %a) {
+; CHECK-LABEL: @and_trunc_nuw_nsw_multiconstant(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], <i32 240, i32 1, i32 242, i32 3>
+; CHECK-NEXT:    [[AND:%.*]] = trunc nuw <4 x i32> [[AND_INNER]] to <4 x i8>
+; CHECK-NEXT:    ret <4 x i8> [[AND]]
+;
+  %t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
+  %and = and <4 x i8> %t1, <i8 240, i8 1, i8 242, i8 3>
+  ret <4 x i8> %and
+}
+
+; Test zero extend with nneg flag and one constant
+define <4 x i32> @or_zext_nneg_constant(<4 x i16> %a) {
+; CHECK-LABEL: @or_zext_nneg_constant(
+; CHECK-NEXT:    [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 1, i16 2, i16 3, i16 4>
+; CHECK-NEXT:    [[Z1:%.*]] = zext nneg <4 x i16> [[A]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[Z1]]
+;
+  %z1 = zext nneg <4 x i16> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @or_zext_nneg_minus_constant(<4 x i8> %a) {
+; CHECK-LABEL: @or_zext_nneg_minus_constant(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <4 x i8> [[A:%.*]], <i8 -16, i8 -15, i8 -14, i8 -13>
+; CHECK-NEXT:    [[OR:%.*]] = zext <4 x i8> [[OR_INNER]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = zext nneg <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 240, i32 241, i32 242, i32 243>
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @or_zext_nneg_multiconstant(<4 x i8> %a) {
+; CHECK-LABEL: @or_zext_nneg_multiconstant(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <4 x i8> [[A:%.*]], <i8 -16, i8 1, i8 -14, i8 3>
+; CHECK-NEXT:    [[OR:%.*]] = zext <4 x i8> [[OR_INNER]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = zext nneg <4 x i8> %a to <4 x i32>
+  %or = or <4 x i32> %z1, <i32 240, i32 1, i32 242, i32 3>
+  ret <4 x i32> %or
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/155216


More information about the llvm-commits mailing list