[llvm] [InstCombine] [X86] pblendvb intrinsics must be replaced by select when possible (PR #137322)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 10:00:06 PDT 2025


https://github.com/vortex73 updated https://github.com/llvm/llvm-project/pull/137322

>From 96ed7cbe1d757c49948ff416364b5a20677f3f73 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Mon, 16 Jun 2025 01:25:01 +0530
Subject: [PATCH 1/2] [InstCombine] VectorCombine Pass

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 58 +++++++++++++++++++
 .../PhaseOrdering/X86/blendv-select.ll        | 30 ++++------
 2 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 04c084ffdda97..480b4f27e963f 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -111,6 +111,7 @@ class VectorCombine {
   bool foldInsExtFNeg(Instruction &I);
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
+  bool foldBitOpOfBitcasts(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
   bool scalarizeBinopOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
@@ -801,6 +802,58 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
+  // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
+  auto *BinOp = dyn_cast<BinaryOperator>(&I);
+  if (!BinOp || !BinOp->isBitwiseLogicOp())
+    return false;
+
+  Value *LHS = BinOp->getOperand(0);
+  Value *RHS = BinOp->getOperand(1);
+
+  // Both operands must be bitcasts
+  auto *LHSCast = dyn_cast<BitCastInst>(LHS);
+  auto *RHSCast = dyn_cast<BitCastInst>(RHS);
+  if (!LHSCast || !RHSCast)
+    return false;
+
+  Value *LHSSrc = LHSCast->getOperand(0);
+  Value *RHSSrc = RHSCast->getOperand(0);
+
+  // Source types must match
+  if (LHSSrc->getType() != RHSSrc->getType())
+    return false;
+
+  // Only handle vector types
+  auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
+  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
+  if (!SrcVecTy || !DstVecTy)
+    return false;
+
+  // Same total bit width
+  if (SrcVecTy->getPrimitiveSizeInBits() != DstVecTy->getPrimitiveSizeInBits())
+    return false;
+
+  // Cost check: prefer operations on narrower element types
+  unsigned SrcEltBits = SrcVecTy->getScalarSizeInBits();
+  unsigned DstEltBits = DstVecTy->getScalarSizeInBits();
+
+  // Prefer smaller element sizes (more elements, finer granularity)
+  if (SrcEltBits > DstEltBits)
+    return false;
+
+  // Create the operation on the source type
+  Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
+                                     BinOp->getName() + ".inner");
+  if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
+    NewBinOp->copyIRFlags(BinOp);
+
+  // Bitcast the result back
+  Value *Result = Builder.CreateBitCast(NewOp, I.getType());
+  replaceValue(I, *Result);
+  return true;
+}
+
 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
 /// destination type followed by shuffle. This can enable further transforms by
 /// moving bitcasts or shuffles together.
@@ -3562,6 +3615,11 @@ bool VectorCombine::run() {
       case Instruction::BitCast:
         MadeChange |= foldBitcastShuffle(I);
         break;
+      case Instruction::And:
+      case Instruction::Or:
+      case Instruction::Xor:
+        MadeChange |= foldBitOpOfBitcasts(I);
+        break;
       default:
         MadeChange |= shrinkType(I);
         break;
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
index 22e4239009dd2..daf4a7b799dd4 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
@@ -477,30 +477,22 @@ define <2 x i64> @PR66513(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c, <2 x i64> %s
 ; CHECK-LABEL: @PR66513(
 ; CHECK-NEXT:    [[I:%.*]] = bitcast <2 x i64> [[A:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I23:%.*]] = icmp sgt <4 x i32> [[I]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I24:%.*]] = sext <4 x i1> [[CMP_I23]] to <4 x i32>
-; CHECK-NEXT:    [[I1:%.*]] = bitcast <4 x i32> [[SEXT_I24]] to <2 x i64>
 ; CHECK-NEXT:    [[I2:%.*]] = bitcast <2 x i64> [[B:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I21:%.*]] = icmp sgt <4 x i32> [[I2]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
-; CHECK-NEXT:    [[I3:%.*]] = bitcast <4 x i32> [[SEXT_I22]] to <2 x i64>
 ; CHECK-NEXT:    [[I4:%.*]] = bitcast <2 x i64> [[C:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I:%.*]] = icmp sgt <4 x i32> [[I4]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
+; CHECK-NEXT:    [[NARROW:%.*]] = select <4 x i1> [[CMP_I21]], <4 x i1> [[CMP_I23]], <4 x i1> zeroinitializer
+; CHECK-NEXT:    [[XOR_I_INNER1:%.*]] = xor <4 x i1> [[NARROW]], [[CMP_I]]
+; CHECK-NEXT:    [[NARROW3:%.*]] = select <4 x i1> [[CMP_I23]], <4 x i1> [[XOR_I_INNER1]], <4 x i1> zeroinitializer
+; CHECK-NEXT:    [[AND_I25_INNER2:%.*]] = and <4 x i1> [[XOR_I_INNER1]], [[CMP_I21]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[SRC:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = select <4 x i1> [[NARROW]], <4 x i32> [[TMP1]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <2 x i64> [[A]] to <4 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <4 x i1> [[NARROW3]], <4 x i32> [[TMP3]], <4 x i32> [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <2 x i64> [[B]] to <4 x i32>
+; CHECK-NEXT:    [[SEXT_I:%.*]] = select <4 x i1> [[AND_I25_INNER2]], <4 x i32> [[TMP5]], <4 x i32> [[TMP4]]
 ; CHECK-NEXT:    [[I5:%.*]] = bitcast <4 x i32> [[SEXT_I]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I27:%.*]] = and <2 x i64> [[I3]], [[I1]]
-; CHECK-NEXT:    [[XOR_I:%.*]] = xor <2 x i64> [[AND_I27]], [[I5]]
-; CHECK-NEXT:    [[AND_I26:%.*]] = and <2 x i64> [[XOR_I]], [[I1]]
-; CHECK-NEXT:    [[AND_I25:%.*]] = and <2 x i64> [[XOR_I]], [[I3]]
-; CHECK-NEXT:    [[AND_I:%.*]] = and <2 x i64> [[AND_I27]], [[SRC:%.*]]
-; CHECK-NEXT:    [[I6:%.*]] = bitcast <2 x i64> [[AND_I]] to <16 x i8>
-; CHECK-NEXT:    [[I7:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
-; CHECK-NEXT:    [[I8:%.*]] = bitcast <2 x i64> [[AND_I26]] to <16 x i8>
-; CHECK-NEXT:    [[I9:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I6]], <16 x i8> [[I7]], <16 x i8> [[I8]])
-; CHECK-NEXT:    [[I12:%.*]] = bitcast <2 x i64> [[B]] to <16 x i8>
-; CHECK-NEXT:    [[I13:%.*]] = bitcast <2 x i64> [[AND_I25]] to <16 x i8>
-; CHECK-NEXT:    [[I14:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I9]], <16 x i8> [[I12]], <16 x i8> [[I13]])
-; CHECK-NEXT:    [[I15:%.*]] = bitcast <16 x i8> [[I14]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[I15]]
+; CHECK-NEXT:    ret <2 x i64> [[I5]]
 ;
   %i = bitcast <2 x i64> %a to <4 x i32>
   %cmp.i23 = icmp sgt <4 x i32> %i, zeroinitializer

>From a78434737ceb241d3c3b9afa126c4370faa2915f Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Mon, 16 Jun 2025 22:28:16 +0530
Subject: [PATCH 2/2] [InstCombine] Minor Tweaks

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 29 ++++++++++++++-----
 .../VectorCombine/AArch64/shrink-types.ll     | 29 ++++++++-----------
 2 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 480b4f27e963f..6b1419210f363 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -824,6 +824,9 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
   if (LHSSrc->getType() != RHSSrc->getType())
     return false;
 
+  if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
+    return false;
+
   // Only handle vector types
   auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
   auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
@@ -831,15 +834,25 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
     return false;
 
   // Same total bit width
-  if (SrcVecTy->getPrimitiveSizeInBits() != DstVecTy->getPrimitiveSizeInBits())
-    return false;
+  assert(SrcVecTy->getPrimitiveSizeInBits() ==
+             DstVecTy->getPrimitiveSizeInBits() &&
+         "Bitcast should preserve total bit width");
+
+  // Cost check:
+  // OldCost = bitwise logic op on DstVecTy + 2 bitcasts (one per operand)
+  // NewCost = bitwise logic op on SrcVecTy + 1 bitcast (of the result)
+  auto OldCost =
+      TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
+      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
+                           TTI::CastContextHint::None) +
+      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
+                           TTI::CastContextHint::None);
+
+  auto NewCost = TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
+                 TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
+                                      TTI::CastContextHint::None);
 
-  // Cost check: prefer operations on narrower element types
-  unsigned SrcEltBits = SrcVecTy->getScalarSizeInBits();
-  unsigned DstEltBits = DstVecTy->getScalarSizeInBits();
-
-  // Prefer smaller element sizes (more elements, finer granularity)
-  if (SrcEltBits > DstEltBits)
+  if (NewCost > OldCost)
     return false;
 
   // Create the operation on the source type
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
index 3c672efbb5a07..761ad80d560e8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -7,9 +7,8 @@ define i32 @test_and(<16 x i32> %a, ptr %b) {
 ; CHECK-LABEL: @test_and(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A:%.*]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
@@ -26,9 +25,8 @@ define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
 ; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], splat (i32 16)
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i32> [[TMP0]], [[A_MASKED]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
@@ -47,15 +45,13 @@ define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
 ; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], splat (i32 255)
 ; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], splat (i32 255)
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i8> [[TMP0]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[TMP2]] to <16 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = and <16 x i8> [[WIDE_LOAD]], splat (i8 15)
-; CHECK-NEXT:    [[TMP5:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
-; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i32> [[TMP7]], [[V_MASKED]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and <16 x i32> [[TMP0]], splat (i32 15)
+; CHECK-NEXT:    [[TMP5:%.*]] = or <16 x i32> [[TMP4]], [[U_MASKED]]
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
 ; CHECK-NEXT:    ret i32 [[TMP9]]
 ;
@@ -81,9 +77,8 @@ define i32 @phi_bug(<16 x i32> %a, ptr %b) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[A_PHI:%.*]] = phi <16 x i32> [ [[A:%.*]], [[ENTRY:%.*]] ]
 ; CHECK-NEXT:    [[WIDE_LOAD_PHI:%.*]] = phi <16 x i8> [ [[WIDE_LOAD]], [[ENTRY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_PHI]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD_PHI]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD_PHI]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A_PHI]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;



More information about the llvm-commits mailing list