[llvm] [InstCombine] [X86] pblendvb intrinsics must be replaced by select when possible (PR #137322)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 25 05:44:20 PDT 2025


https://github.com/vortex73 updated https://github.com/llvm/llvm-project/pull/137322

From d5d68a1fc82b081189059eb80958ba0e4c5de8e9 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 25 Apr 2025 18:01:18 +0530
Subject: [PATCH 1/2] [InstCombine] Pre-Commit Tests

---
 llvm/test/Transforms/InstCombine/pblend.ll | 63 ++++++++++++++++++++++
 1 file changed, 63 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/pblend.ll

diff --git a/llvm/test/Transforms/InstCombine/pblend.ll b/llvm/test/Transforms/InstCombine/pblend.ll
new file mode 100644
index 0000000000000..e4a6cb9a8c856
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pblend.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define <2 x i64> @tricky(<2 x i64> noundef %a, <2 x i64> noundef %b, <2 x i64> noundef %c, <2 x i64> noundef %src) {
+; CHECK-LABEL: define <2 x i64> @tricky(
+; CHECK-SAME: <2 x i64> noundef [[A:%.*]], <2 x i64> noundef [[B:%.*]], <2 x i64> noundef [[C:%.*]], <2 x i64> noundef [[SRC:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <2 x i64> [[A]] to <4 x i32>
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp sgt <4 x i32> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[SEXT_I]] to <2 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[B]] to <4 x i32>
+; CHECK-NEXT:    [[CMP_I21:%.*]] = icmp sgt <4 x i32> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i32> [[SEXT_I22]] to <2 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <2 x i64> [[C]] to <4 x i32>
+; CHECK-NEXT:    [[CMP_I23:%.*]] = icmp sgt <4 x i32> [[TMP4]], zeroinitializer
+; CHECK-NEXT:    [[SEXT_I24:%.*]] = sext <4 x i1> [[CMP_I23]] to <4 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <4 x i32> [[SEXT_I24]] to <2 x i64>
+; CHECK-NEXT:    [[AND_I:%.*]] = and <2 x i64> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[XOR_I:%.*]] = xor <2 x i64> [[AND_I]], [[TMP5]]
+; CHECK-NEXT:    [[AND_I25:%.*]] = and <2 x i64> [[XOR_I]], [[TMP1]]
+; CHECK-NEXT:    [[AND_I26:%.*]] = and <2 x i64> [[XOR_I]], [[TMP3]]
+; CHECK-NEXT:    [[AND_I27:%.*]] = and <2 x i64> [[AND_I]], [[SRC]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <2 x i64> [[AND_I27]] to <16 x i8>
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <2 x i64> [[AND_I25]] to <16 x i8>
+; CHECK-NEXT:    [[TMP9:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP6]], <16 x i8> [[TMP7]], <16 x i8> [[TMP8]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <2 x i64> [[B]] to <16 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <2 x i64> [[AND_I26]] to <16 x i8>
+; CHECK-NEXT:    [[TMP12:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i8> [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <16 x i8> [[TMP12]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP13]]
+;
+entry:
+  %0 = bitcast <2 x i64> %a to <4 x i32>
+  %cmp.i = icmp sgt <4 x i32> %0, zeroinitializer
+  %sext.i = sext <4 x i1> %cmp.i to <4 x i32>
+  %1 = bitcast <4 x i32> %sext.i to <2 x i64>
+  %2 = bitcast <2 x i64> %b to <4 x i32>
+  %cmp.i21 = icmp sgt <4 x i32> %2, zeroinitializer
+  %sext.i22 = sext <4 x i1> %cmp.i21 to <4 x i32>
+  %3 = bitcast <4 x i32> %sext.i22 to <2 x i64>
+  %4 = bitcast <2 x i64> %c to <4 x i32>
+  %cmp.i23 = icmp sgt <4 x i32> %4, zeroinitializer
+  %sext.i24 = sext <4 x i1> %cmp.i23 to <4 x i32>
+  %5 = bitcast <4 x i32> %sext.i24 to <2 x i64>
+  %and.i = and <2 x i64> %3, %1
+  %xor.i = xor <2 x i64> %and.i, %5
+  %and.i25 = and <2 x i64> %xor.i, %1
+  %and.i26 = and <2 x i64> %xor.i, %3
+  %and.i27 = and <2 x i64> %and.i, %src
+  %6 = bitcast <2 x i64> %and.i27 to <16 x i8>
+  %7 = bitcast <2 x i64> %a to <16 x i8>
+  %8 = bitcast <2 x i64> %and.i25 to <16 x i8>
+  %9 = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %6, <16 x i8> %7, <16 x i8> %8)
+  %10 = bitcast <2 x i64> %b to <16 x i8>
+  %11 = bitcast <2 x i64> %and.i26 to <16 x i8>
+  %12 = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %9, <16 x i8> %10, <16 x i8> %11)
+  %13 = bitcast <16 x i8> %12 to <2 x i64>
+  ret <2 x i64> %13
+}
+declare <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8>, <16 x i8>, <16 x i8>)
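
For reference, InstCombine already turns a pblendvb whose mask is a plain sign-extended bool vector into a select; the test above builds its masks from and/xor logic over several such sign extensions, which the current code leaves alone. A rough sketch of the two shapes (value names here are illustrative, not taken from the test):

  ; mask is a bare sext - already folded today:
  %m = sext <16 x i1> %cond to <16 x i8>
  %r = call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %x, <16 x i8> %y, <16 x i8> %m)
  ;   --> %r = select <16 x i1> %cond, <16 x i8> %y, <16 x i8> %x

  ; mask is logic over sexts - the shape patch 2 tries to handle:
  %sa  = sext <4 x i1> %a to <4 x i32>
  %sb  = sext <4 x i1> %b to <4 x i32>
  %and = and <4 x i32> %sa, %sb
  %m2  = bitcast <4 x i32> %and to <16 x i8>
  ;   the blend on %m2 should become a select on (and <4 x i1> %a, %b)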

From d13355814f70ad3d1ec740f2c32dde73536d56fd Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 25 Apr 2025 18:13:59 +0530
Subject: [PATCH 2/2] [InstCombine] Enhance the pblendvb -> select conversion to
 handle complex boolean masks

---
 .../Target/X86/X86InstCombineIntrinsic.cpp    | 207 +++++++++++++++---
 1 file changed, 171 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index c4d349044fe80..0eb7c43f8be14 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -52,6 +52,121 @@ static Value *getBoolVecFromMask(Value *Mask, const DataLayout &DL) {
   return nullptr;
 }
 
+// Helper function to decompose complex logic on sign-extended i1 vectors
+static Value *tryDecomposeVectorLogicMask(Value *Mask, IRBuilderBase &Builder) {
+  // Look through bitcasts
+  Mask = InstCombiner::peekThroughBitcast(Mask);
+
+  // Direct sign-extension case (should be caught by the main code path)
+  Value *InnerVal;
+  if (match(Mask, m_SExt(m_Value(InnerVal))) &&
+      InnerVal->getType()->isVectorTy() &&
+      InnerVal->getType()->getScalarType()->isIntegerTy(1))
+    return InnerVal;
+
+  // Handle AND of sign-extended vectors: (sext A) & (sext B) -> sext(A & B)
+  Value *LHS, *RHS;
+  Value *LHSInner, *RHSInner;
+  if (match(Mask, m_And(m_Value(LHS), m_Value(RHS)))) {
+    LHS = InstCombiner::peekThroughBitcast(LHS);
+    RHS = InstCombiner::peekThroughBitcast(RHS);
+
+    if (match(LHS, m_SExt(m_Value(LHSInner))) && 
+        LHSInner->getType()->isVectorTy() &&
+        LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        match(RHS, m_SExt(m_Value(RHSInner))) &&
+        RHSInner->getType()->isVectorTy() &&
+        RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        LHSInner->getType() == RHSInner->getType()) {
+      return Builder.CreateAnd(LHSInner, RHSInner);
+    }
+
+    // Try recursively on each operand
+    Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+    Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+    if (DecomposedLHS && DecomposedRHS && 
+        DecomposedLHS->getType() == DecomposedRHS->getType())
+      return Builder.CreateAnd(DecomposedLHS, DecomposedRHS);
+  }
+
+  // Handle XOR of sign-extended vectors: (sext A) ^ (sext B) -> sext(A ^ B)
+  if (match(Mask, m_Xor(m_Value(LHS), m_Value(RHS)))) {
+    LHS = InstCombiner::peekThroughBitcast(LHS);
+    RHS = InstCombiner::peekThroughBitcast(RHS);
+
+    if (match(LHS, m_SExt(m_Value(LHSInner))) && 
+        LHSInner->getType()->isVectorTy() &&
+        LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        match(RHS, m_SExt(m_Value(RHSInner))) &&
+        RHSInner->getType()->isVectorTy() &&
+        RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        LHSInner->getType() == RHSInner->getType()) {
+      return Builder.CreateXor(LHSInner, RHSInner);
+    }
+
+    // Try recursively on each operand
+    Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+    Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+    if (DecomposedLHS && DecomposedRHS && 
+        DecomposedLHS->getType() == DecomposedRHS->getType())
+      return Builder.CreateXor(DecomposedLHS, DecomposedRHS);
+  }
+
+  // Handle OR of sign-extended vectors: (sext A) | (sext B) -> sext(A | B)
+  if (match(Mask, m_Or(m_Value(LHS), m_Value(RHS)))) {
+    LHS = InstCombiner::peekThroughBitcast(LHS);
+    RHS = InstCombiner::peekThroughBitcast(RHS);
+
+    if (match(LHS, m_SExt(m_Value(LHSInner))) && 
+        LHSInner->getType()->isVectorTy() &&
+        LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        match(RHS, m_SExt(m_Value(RHSInner))) &&
+        RHSInner->getType()->isVectorTy() &&
+        RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        LHSInner->getType() == RHSInner->getType()) {
+      return Builder.CreateOr(LHSInner, RHSInner);
+    }
+
+    // Try recursively on each operand
+    Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+    Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+    if (DecomposedLHS && DecomposedRHS && 
+        DecomposedLHS->getType() == DecomposedRHS->getType())
+      return Builder.CreateOr(DecomposedLHS, DecomposedRHS);
+  }
+
+  // Handle AndNot: (sext A) & ~(sext B) -> sext(A & ~B)
+  Value *NotOp;
+  if (match(Mask, m_And(m_Value(LHS), 
+                        m_Not(m_Value(NotOp))))) {
+    LHS = InstCombiner::peekThroughBitcast(LHS);
+    NotOp = InstCombiner::peekThroughBitcast(NotOp);
+
+    if (match(LHS, m_SExt(m_Value(LHSInner))) && 
+        LHSInner->getType()->isVectorTy() &&
+        LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        match(NotOp, m_SExt(m_Value(RHSInner))) &&
+        RHSInner->getType()->isVectorTy() &&
+        RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+        LHSInner->getType() == RHSInner->getType()) {
+      Value *NotRHSInner = Builder.CreateNot(RHSInner);
+      return Builder.CreateAnd(LHSInner, NotRHSInner);
+    }
+
+    // Try recursively on each operand
+    Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+    Value *DecomposedNotOp = tryDecomposeVectorLogicMask(NotOp, Builder);
+    if (DecomposedLHS && DecomposedNotOp && 
+        DecomposedLHS->getType() == DecomposedNotOp->getType()) {
+      Value *NotRHS = Builder.CreateNot(DecomposedNotOp);
+      return Builder.CreateAnd(DecomposedLHS, NotRHS);
+    }
+  }
+
+  // No matching pattern found
+  return nullptr;
+}
+
 // TODO: If the x86 backend knew how to convert a bool vector mask back to an
 // XMM register mask efficiently, we could transform all x86 masked intrinsics
 // to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -2150,6 +2268,52 @@ static bool simplifyX86VPERMMask(Instruction *II, bool IsBinary,
   return IC.SimplifyDemandedBits(II, /*OpNo=*/1, DemandedMask, KnownMask);
 }
 
+
+// Rebuild the blend as a select on BoolVec, bitcasting operands if needed.
+static Instruction *createMaskSelect(InstCombiner &IC, CallInst &II,
+                                     Value *BoolVec, Value *Op0, Value *Op1,
+                                     Value *MaskSrc = nullptr,
+                                     ArrayRef<int> ShuffleMask = {}) {
+  auto *OpTy = cast<FixedVectorType>(II.getType());
+  unsigned NumOperandElts = OpTy->getNumElements();
+  unsigned OpSizeInBits = OpTy->getPrimitiveSizeInBits().getFixedValue();
+  // Derive the mask lane count from the bool vector itself - it is usually
+  // narrower than the intrinsic's byte-sized mask operand.
+  unsigned NumMaskElts =
+      cast<FixedVectorType>(BoolVec->getType())->getNumElements();
+
+  // If we peeked through a shuffle, reapply the shuffle to the bool vector.
+  if (MaskSrc) {
+    unsigned NumMaskSrcElts =
+        cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
+    NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
+    // Multiple mask bits maps to the same operand element - bail out.
+    if (NumMaskElts > NumOperandElts)
+      return nullptr;
+    SmallVector<int> ScaledMask;
+    if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
+      return nullptr;
+    BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
+  }
+
+  if (NumMaskElts == NumOperandElts)
+    return SelectInst::Create(BoolVec, Op1, Op0);
+
+  // If the mask has less elements than the operands, each mask bit maps to
+  // multiple elements of the operands. Bitcast back and forth.
+  if (NumMaskElts < NumOperandElts && OpSizeInBits % NumMaskElts == 0) {
+    auto *MaskTy = FixedVectorType::get(
+        IC.Builder.getIntNTy(OpSizeInBits / NumMaskElts), NumMaskElts);
+    Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
+    Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
+    Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
+    return new BitCastInst(Sel, II.getType());
+  }
+
+  return nullptr;
+}
+
+
 std::optional<Instruction *>
 X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
   auto SimplifyDemandedVectorEltsLow = [&IC](Value *Op, unsigned Width,
@@ -2914,42 +3078,16 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     if (match(Mask, m_SExt(m_Value(BoolVec))) &&
         BoolVec->getType()->isVectorTy() &&
         BoolVec->getType()->getScalarSizeInBits() == 1) {
-      auto *MaskTy = cast<FixedVectorType>(Mask->getType());
-      auto *OpTy = cast<FixedVectorType>(II.getType());
-      unsigned NumMaskElts = MaskTy->getNumElements();
-      unsigned NumOperandElts = OpTy->getNumElements();
-
-      // If we peeked through a shuffle, reapply the shuffle to the bool vector.
-      if (MaskSrc) {
-        unsigned NumMaskSrcElts =
-            cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
-        NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
-        // Multiple mask bits maps to the same operand element - bail out.
-        if (NumMaskElts > NumOperandElts)
-          break;
-        SmallVector<int> ScaledMask;
-        if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
-          break;
-        BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
-        MaskTy = FixedVectorType::get(MaskTy->getElementType(), NumMaskElts);
-      }
-      assert(MaskTy->getPrimitiveSizeInBits() ==
-                 OpTy->getPrimitiveSizeInBits() &&
-             "Not expecting mask and operands with different sizes");
-
-      if (NumMaskElts == NumOperandElts) {
-        return SelectInst::Create(BoolVec, Op1, Op0);
-      }
-
-      // If the mask has less elements than the operands, each mask bit maps to
-      // multiple elements of the operands. Bitcast back and forth.
-      if (NumMaskElts < NumOperandElts) {
-        Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
-        Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
-        Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
-        return new BitCastInst(Sel, II.getType());
-      }
-    }
+      if (Instruction *Select =
+              createMaskSelect(IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask))
+        return Select;
+    } else {
+      BoolVec = tryDecomposeVectorLogicMask(Mask, IC.Builder);
+      if (BoolVec)
+        if (Instruction *Select = createMaskSelect(
+                IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask))
+          return Select;
+    }
 
     break;
   }
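
Taken together, the intent is that the pblendvb case first tries the existing bare-sext path, and only then asks tryDecomposeVectorLogicMask to recover an i1 vector from an and/or/xor/andnot tree of sign extensions; createMaskSelect then rebuilds the blend as a select, bitcasting the byte operands when the recovered condition is narrower than the i8 mask. A rough sketch of the intended rewrite for a mask of the form ((sext %a) & (sext %b)) ^ (sext %c), with illustrative value names rather than actual pass output:

  %ab   = and <4 x i1> %a, %b
  %cond = xor <4 x i1> %ab, %c
  %x32  = bitcast <16 x i8> %x to <4 x i32>
  %y32  = bitcast <16 x i8> %y to <4 x i32>
  %sel  = select <4 x i1> %cond, <4 x i32> %y32, <4 x i32> %x32
  %res  = bitcast <4 x i32> %sel to <16 x i8>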



More information about the llvm-commits mailing list