[llvm] [InstCombine] [X86] pblendvb intrinsics must be replaced by select when possible (PR #137322)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 25 05:44:20 PDT 2025
https://github.com/vortex73 updated https://github.com/llvm/llvm-project/pull/137322
From d5d68a1fc82b081189059eb80958ba0e4c5de8e9 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 25 Apr 2025 18:01:18 +0530
Subject: [PATCH 1/2] [InstCombine] Pre-Commit Tests
---
llvm/test/Transforms/InstCombine/pblend.ll | 63 ++++++++++++++++++++++
1 file changed, 63 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/pblend.ll
diff --git a/llvm/test/Transforms/InstCombine/pblend.ll b/llvm/test/Transforms/InstCombine/pblend.ll
new file mode 100644
index 0000000000000..e4a6cb9a8c856
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pblend.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define <2 x i64> @tricky(<2 x i64> noundef %a, <2 x i64> noundef %b, <2 x i64> noundef %c, <2 x i64> noundef %src) {
+; CHECK-LABEL: define <2 x i64> @tricky(
+; CHECK-SAME: <2 x i64> noundef [[A:%.*]], <2 x i64> noundef [[B:%.*]], <2 x i64> noundef [[C:%.*]], <2 x i64> noundef [[SRC:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast <2 x i64> [[A]] to <4 x i32>
+; CHECK-NEXT: [[CMP_I:%.*]] = icmp sgt <4 x i32> [[TMP0]], zeroinitializer
+; CHECK-NEXT: [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i32> [[SEXT_I]] to <2 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i64> [[B]] to <4 x i32>
+; CHECK-NEXT: [[CMP_I21:%.*]] = icmp sgt <4 x i32> [[TMP2]], zeroinitializer
+; CHECK-NEXT: [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i32> [[SEXT_I22]] to <2 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i64> [[C]] to <4 x i32>
+; CHECK-NEXT: [[CMP_I23:%.*]] = icmp sgt <4 x i32> [[TMP4]], zeroinitializer
+; CHECK-NEXT: [[SEXT_I24:%.*]] = sext <4 x i1> [[CMP_I23]] to <4 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i32> [[SEXT_I24]] to <2 x i64>
+; CHECK-NEXT: [[AND_I:%.*]] = and <2 x i64> [[TMP3]], [[TMP1]]
+; CHECK-NEXT: [[XOR_I:%.*]] = xor <2 x i64> [[AND_I]], [[TMP5]]
+; CHECK-NEXT: [[AND_I25:%.*]] = and <2 x i64> [[XOR_I]], [[TMP1]]
+; CHECK-NEXT: [[AND_I26:%.*]] = and <2 x i64> [[XOR_I]], [[TMP3]]
+; CHECK-NEXT: [[AND_I27:%.*]] = and <2 x i64> [[AND_I]], [[SRC]]
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <2 x i64> [[AND_I27]] to <16 x i8>
+; CHECK-NEXT: [[TMP7:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <2 x i64> [[AND_I25]] to <16 x i8>
+; CHECK-NEXT: [[TMP9:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP6]], <16 x i8> [[TMP7]], <16 x i8> [[TMP8]])
+; CHECK-NEXT: [[TMP10:%.*]] = bitcast <2 x i64> [[B]] to <16 x i8>
+; CHECK-NEXT: [[TMP11:%.*]] = bitcast <2 x i64> [[AND_I26]] to <16 x i8>
+; CHECK-NEXT: [[TMP12:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i8> [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = bitcast <16 x i8> [[TMP12]] to <2 x i64>
+; CHECK-NEXT: ret <2 x i64> [[TMP13]]
+;
+entry:
+ %0 = bitcast <2 x i64> %a to <4 x i32>
+ %cmp.i = icmp sgt <4 x i32> %0, zeroinitializer
+ %sext.i = sext <4 x i1> %cmp.i to <4 x i32>
+ %1 = bitcast <4 x i32> %sext.i to <2 x i64>
+ %2 = bitcast <2 x i64> %b to <4 x i32>
+ %cmp.i21 = icmp sgt <4 x i32> %2, zeroinitializer
+ %sext.i22 = sext <4 x i1> %cmp.i21 to <4 x i32>
+ %3 = bitcast <4 x i32> %sext.i22 to <2 x i64>
+ %4 = bitcast <2 x i64> %c to <4 x i32>
+ %cmp.i23 = icmp sgt <4 x i32> %4, zeroinitializer
+ %sext.i24 = sext <4 x i1> %cmp.i23 to <4 x i32>
+ %5 = bitcast <4 x i32> %sext.i24 to <2 x i64>
+ %and.i = and <2 x i64> %3, %1
+ %xor.i = xor <2 x i64> %and.i, %5
+ %and.i25 = and <2 x i64> %xor.i, %1
+ %and.i26 = and <2 x i64> %xor.i, %3
+ %and.i27 = and <2 x i64> %and.i, %src
+ %6 = bitcast <2 x i64> %and.i27 to <16 x i8>
+ %7 = bitcast <2 x i64> %a to <16 x i8>
+ %8 = bitcast <2 x i64> %and.i25 to <16 x i8>
+ %9 = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %6, <16 x i8> %7, <16 x i8> %8)
+ %10 = bitcast <2 x i64> %b to <16 x i8>
+ %11 = bitcast <2 x i64> %and.i26 to <16 x i8>
+ %12 = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %9, <16 x i8> %10, <16 x i8> %11)
+ %13 = bitcast <16 x i8> %12 to <2 x i64>
+ ret <2 x i64> %13
+}
+declare <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8>, <16 x i8>, <16 x i8>)
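
For context, the chained blends in this pre-commit test correspond roughly to the intrinsics-level source below. This is a hand-written reconstruction for illustration only (it is not taken from the PR); the function name mirrors the test, the Intel intrinsics are the standard SSE4.1 ones, and it needs -msse4.1 to compile. Every blend mask is bitwise logic over sign-extended compare results, which is exactly the shape the second patch tries to decompose.

#include <immintrin.h>

// Hypothetical reconstruction of source that lowers to the IR above: three
// sign-bit masks combined with AND/XOR feed two chained _mm_blendv_epi8 calls.
__m128i tricky(__m128i a, __m128i b, __m128i c, __m128i src) {
  const __m128i zero = _mm_setzero_si128();
  __m128i ma = _mm_cmpgt_epi32(a, zero); // sext of (a > 0), lanes of 0 / -1
  __m128i mb = _mm_cmpgt_epi32(b, zero);
  __m128i mc = _mm_cmpgt_epi32(c, zero);
  __m128i ab = _mm_and_si128(mb, ma);    // %and.i
  __m128i x = _mm_xor_si128(ab, mc);     // %xor.i
  __m128i sel_a = _mm_and_si128(x, ma);  // %and.i25
  __m128i sel_b = _mm_and_si128(x, mb);  // %and.i26
  __m128i r = _mm_and_si128(ab, src);    // %and.i27
  r = _mm_blendv_epi8(r, a, sel_a);      // first pblendvb
  r = _mm_blendv_epi8(r, b, sel_b);      // second pblendvb
  return r;
}
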
From d13355814f70ad3d1ec740f2c32dde73536d56fd Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 25 Apr 2025 18:13:59 +0530
Subject: [PATCH 2/2] [InstCombine] Enhance pblendvb to select conversion with
complex boolean masks
---
.../Target/X86/X86InstCombineIntrinsic.cpp | 210 +++++++++++++++---
1 file changed, 174 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index c4d349044fe80..0eb7c43f8be14 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -52,6 +52,124 @@ static Value *getBoolVecFromMask(Value *Mask, const DataLayout &DL) {
return nullptr;
}
+// Helper to look through bitwise logic (and/or/xor/andnot) over sign-extended
+// i1 vectors and rebuild the equivalent logic directly on the i1 vectors, so
+// the result can be used as a vector select condition.
+static Value *tryDecomposeVectorLogicMask(Value *Mask, IRBuilderBase &Builder) {
+ // Look through bitcasts
+ Mask = InstCombiner::peekThroughBitcast(Mask);
+
+ // Direct sign-extension case (should be caught by the main code path)
+ Value *InnerVal;
+ if (match(Mask, m_SExt(m_Value(InnerVal))) &&
+ InnerVal->getType()->isVectorTy() &&
+ InnerVal->getType()->getScalarType()->isIntegerTy(1))
+ return InnerVal;
+
+ // Handle AND of sign-extended vectors: (sext A) & (sext B) -> sext(A & B)
+ Value *LHS, *RHS;
+ Value *LHSInner, *RHSInner;
+ if (match(Mask, m_And(m_Value(LHS), m_Value(RHS)))) {
+ LHS = InstCombiner::peekThroughBitcast(LHS);
+ RHS = InstCombiner::peekThroughBitcast(RHS);
+
+ if (match(LHS, m_SExt(m_Value(LHSInner))) &&
+ LHSInner->getType()->isVectorTy() &&
+ LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ match(RHS, m_SExt(m_Value(RHSInner))) &&
+ RHSInner->getType()->isVectorTy() &&
+ RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ LHSInner->getType() == RHSInner->getType()) {
+ return Builder.CreateAnd(LHSInner, RHSInner);
+ }
+
+ // Try recursively on each operand
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+ if (DecomposedLHS && DecomposedRHS &&
+ DecomposedLHS->getType() == DecomposedRHS->getType())
+ return Builder.CreateAnd(DecomposedLHS, DecomposedRHS);
+ }
+
+ // Handle XOR of sign-extended vectors: (sext A) ^ (sext B) -> sext(A ^ B)
+ if (match(Mask, m_Xor(m_Value(LHS), m_Value(RHS)))) {
+ LHS = InstCombiner::peekThroughBitcast(LHS);
+ RHS = InstCombiner::peekThroughBitcast(RHS);
+
+ if (match(LHS, m_SExt(m_Value(LHSInner))) &&
+ LHSInner->getType()->isVectorTy() &&
+ LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ match(RHS, m_SExt(m_Value(RHSInner))) &&
+ RHSInner->getType()->isVectorTy() &&
+ RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ LHSInner->getType() == RHSInner->getType()) {
+ return Builder.CreateXor(LHSInner, RHSInner);
+ }
+
+ // Try recursively on each operand
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+ if (DecomposedLHS && DecomposedRHS &&
+ DecomposedLHS->getType() == DecomposedRHS->getType())
+ return Builder.CreateXor(DecomposedLHS, DecomposedRHS);
+ }
+
+ // Handle OR of sign-extended vectors: (sext A) | (sext B) -> sext(A | B)
+ if (match(Mask, m_Or(m_Value(LHS), m_Value(RHS)))) {
+ LHS = InstCombiner::peekThroughBitcast(LHS);
+ RHS = InstCombiner::peekThroughBitcast(RHS);
+
+ if (match(LHS, m_SExt(m_Value(LHSInner))) &&
+ LHSInner->getType()->isVectorTy() &&
+ LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ match(RHS, m_SExt(m_Value(RHSInner))) &&
+ RHSInner->getType()->isVectorTy() &&
+ RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ LHSInner->getType() == RHSInner->getType()) {
+ return Builder.CreateOr(LHSInner, RHSInner);
+ }
+
+ // Try recursively on each operand
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
+ if (DecomposedLHS && DecomposedRHS &&
+ DecomposedLHS->getType() == DecomposedRHS->getType())
+ return Builder.CreateOr(DecomposedLHS, DecomposedRHS);
+ }
+
+ // Handle AndNot: (sext A) & ~(sext B) -> sext(A & ~B)
+ Value *NotOp;
+ if (match(Mask, m_And(m_Value(LHS),
+ m_Not(m_Value(NotOp))))) {
+ LHS = InstCombiner::peekThroughBitcast(LHS);
+ NotOp = InstCombiner::peekThroughBitcast(NotOp);
+
+ if (match(LHS, m_SExt(m_Value(LHSInner))) &&
+ LHSInner->getType()->isVectorTy() &&
+ LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ match(NotOp, m_SExt(m_Value(RHSInner))) &&
+ RHSInner->getType()->isVectorTy() &&
+ RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
+ LHSInner->getType() == RHSInner->getType()) {
+ Value *NotRHSInner = Builder.CreateNot(RHSInner);
+ return Builder.CreateAnd(LHSInner, NotRHSInner);
+ }
+
+ // Try recursively on each operand
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
+ Value *DecomposedNotOp = tryDecomposeVectorLogicMask(NotOp, Builder);
+ if (DecomposedLHS && DecomposedNotOp &&
+ DecomposedLHS->getType() == DecomposedNotOp->getType()) {
+ Value *NotRHS = Builder.CreateNot(DecomposedNotOp);
+ return Builder.CreateAnd(DecomposedLHS, NotRHS);
+ }
+ }
+
+ // No matching pattern found
+ return nullptr;
+}
+
// TODO: If the x86 backend knew how to convert a bool vector mask back to an
// XMM register mask efficiently, we could transform all x86 masked intrinsics
// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -2150,6 +2268,52 @@ static bool simplifyX86VPERMMask(Instruction *II, bool IsBinary,
return IC.SimplifyDemandedBits(II, /*OpNo=*/1, DemandedMask, KnownMask);
}
+
+static Instruction *createMaskSelect(InstCombiner &IC, CallInst &II,
+ Value *BoolVec, Value *Op0, Value *Op1,
+ Value *MaskSrc = nullptr,
+ ArrayRef<int> ShuffleMask = std::nullopt) {
+ auto *OpTy = cast<FixedVectorType>(II.getType());
+ unsigned NumMaskElts =
+ cast<FixedVectorType>(BoolVec->getType())->getNumElements();
+ unsigned NumOperandElts = OpTy->getNumElements();
+
+ // If we peeked through a shuffle, reapply the shuffle to the bool vector.
+ if (MaskSrc) {
+ unsigned NumMaskSrcElts =
+ cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
+ NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
+ // Multiple mask bits map to the same operand element - bail out.
+ if (NumMaskElts > NumOperandElts)
+ return nullptr;
+ SmallVector<int> ScaledMask;
+ if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
+ return nullptr;
+ BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
+ }
+
+ if (NumMaskElts == NumOperandElts)
+ return SelectInst::Create(BoolVec, Op1, Op0);
+
+ // If the mask has fewer elements than the operands, each mask bit maps to
+ // multiple elements of the operands. Bitcast back and forth through an
+ // integer vector whose element count matches the bool vector. The type is
+ // derived from the bool vector rather than from the intrinsic's mask
+ // operand (which is always the fixed byte/float vector type), so this also
+ // works when BoolVec came from a decomposed logic mask.
+ unsigned OpBits = OpTy->getPrimitiveSizeInBits();
+ if (NumMaskElts < NumOperandElts && OpBits % NumMaskElts == 0) {
+ auto *MaskTy = FixedVectorType::get(
+ IntegerType::get(II.getContext(), OpBits / NumMaskElts), NumMaskElts);
+ Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
+ Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
+ Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
+ return new BitCastInst(Sel, II.getType());
+ }
+
+ return nullptr;
+}
+
std::optional<Instruction *>
X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
auto SimplifyDemandedVectorEltsLow = [&IC](Value *Op, unsigned Width,
@@ -2914,42 +3078,16 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
if (match(Mask, m_SExt(m_Value(BoolVec))) &&
BoolVec->getType()->isVectorTy() &&
BoolVec->getType()->getScalarSizeInBits() == 1) {
- auto *MaskTy = cast<FixedVectorType>(Mask->getType());
- auto *OpTy = cast<FixedVectorType>(II.getType());
- unsigned NumMaskElts = MaskTy->getNumElements();
- unsigned NumOperandElts = OpTy->getNumElements();
-
- // If we peeked through a shuffle, reapply the shuffle to the bool vector.
- if (MaskSrc) {
- unsigned NumMaskSrcElts =
- cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
- NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
- // Multiple mask bits maps to the same operand element - bail out.
- if (NumMaskElts > NumOperandElts)
- break;
- SmallVector<int> ScaledMask;
- if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
- break;
- BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
- MaskTy = FixedVectorType::get(MaskTy->getElementType(), NumMaskElts);
- }
- assert(MaskTy->getPrimitiveSizeInBits() ==
- OpTy->getPrimitiveSizeInBits() &&
- "Not expecting mask and operands with different sizes");
-
- if (NumMaskElts == NumOperandElts) {
- return SelectInst::Create(BoolVec, Op1, Op0);
- }
-
- // If the mask has less elements than the operands, each mask bit maps to
- // multiple elements of the operands. Bitcast back and forth.
- if (NumMaskElts < NumOperandElts) {
- Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
- Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
- Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
- return new BitCastInst(Sel, II.getType());
- }
- }
+ if (Instruction *Select =
+ createMaskSelect(IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask))
+ return Select;
+ } else if (Value *Decomposed =
+ tryDecomposeVectorLogicMask(Mask, IC.Builder)) {
+ if (Instruction *Select =
+ createMaskSelect(IC, II, Decomposed, Op0, Op1, MaskSrc, ShuffleMask))
+ return Select;
+ }
break;
}
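
For readers following the mask logic, here is a more compact sketch of the same decomposition idea: recursively strip the sign-extensions and rebuild the and/or/xor/not directly on the i1 vectors, relying on sext(A) op sext(B) == sext(A op B) for bitwise ops on sign-replicated lanes. This is an illustration only, not the code as posted in the patch; the name decomposeSExtLogic is made up, and, like the patch, it can leave dead instructions behind when an outer match fails.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Sketch: return an i1 vector equivalent to the sign bits of Mask, or null.
static Value *decomposeSExtLogic(Value *Mask, IRBuilderBase &Builder) {
  Mask = InstCombiner::peekThroughBitcast(Mask);

  // sext(<N x i1> A) --> A.
  Value *Inner;
  if (match(Mask, m_SExt(m_Value(Inner))) && Inner->getType()->isVectorTy() &&
      Inner->getType()->getScalarType()->isIntegerTy(1))
    return Inner;

  // ~X --> ~(decomposed X), which also covers "andnot"-style masks.
  Value *NotOp;
  if (match(Mask, m_Not(m_Value(NotOp))))
    if (Value *D = decomposeSExtLogic(NotOp, Builder))
      return Builder.CreateNot(D);

  // X op Y --> (decomposed X) op (decomposed Y) for op in {and, or, xor}.
  Value *L, *R;
  if (match(Mask, m_And(m_Value(L), m_Value(R))) ||
      match(Mask, m_Or(m_Value(L), m_Value(R))) ||
      match(Mask, m_Xor(m_Value(L), m_Value(R)))) {
    Value *DL = decomposeSExtLogic(L, Builder);
    Value *DR = decomposeSExtLogic(R, Builder);
    if (DL && DR && DL->getType() == DR->getType())
      return Builder.CreateBinOp(cast<BinaryOperator>(Mask)->getOpcode(), DL,
                                 DR);
  }
  return nullptr;
}

Hooked up at the pblendvb call site in the same way as the patch, a helper of this shape would be expected to let masks like those in the tricky test above collapse into selects on <4 x i1> conditions.
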