[llvm] [X86] SimplifyDemandedBitsForTargetNode: add X86ISD::BZHI handling (PR #177881)
Saina Daneshmand via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 25 13:34:01 PST 2026
https://github.com/SainaDaneshmandjahromi created https://github.com/llvm/llvm-project/pull/177881
This patch adds SimplifyDemandedBitsForTargetNode support for X86ISD::BZHI.
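For BZHI, only the low 8 bits of the second (index) operand, referred to as the mask below, are semantically relevant: the instruction zeroes the source bits at and above the bit position given by that 8-bit value.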
This change teaches SimplifyDemandedBitsForTargetNode to:
* Limit demanded bits of the mask operand to mask[7:0]
* Use KnownBits information to compute an upper bound on mask[7:0]
* When the maximum possible value of mask[7:0] is less than the operand
bitwidth, treat result bits at and above that bound as always zero and
mark the corresponding source bits as not demanded
Tests are added to combine-bzhi.ll.
Closes #177369
From 2d9109491c863c1c0554ebedf3b99155cc926f0a Mon Sep 17 00:00:00 2001
From: SainaDaneshmandjahromi <daneshmand.saina at gmail.com>
Date: Sun, 25 Jan 2026 14:18:55 -0700
Subject: [PATCH 1/2] [X86] Improve BZHI SimplifyDemandedBits
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 58 +++++++++++++++++++++++++
llvm/test/CodeGen/X86/combine-bzhi.ll | 47 ++++++++++++++++++++
2 files changed, 105 insertions(+)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e54f4ed2fb26c..ca33b70f1a61b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45440,6 +45440,64 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
break;
}
+case X86ISD::BZHI: {
+ SDValue Op0 = Op.getOperand(0); // src
+ SDValue Op1 = Op.getOperand(1); // mask
+
+ // Rule 1: Only the bottom 8 bits of the mask are required.
+ // Track an upper bound on mask[7:0] so we can apply Rule 2.
+ uint64_t MaxMask8 = 255;
+
+ if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
+ // NOTE: SimplifyDemandedBits won't do this for constants.
+ uint64_t Val1 = Cst1->getZExtValue();
+ uint64_t MaskedVal1 = Val1 & 0xFF;
+
+ if (MaskedVal1 != Val1) {
+ SDLoc DL(Op);
+ EVT MaskVT = Op1.getValueType();
+ SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
+ return TLO.CombineTo(Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
+ }
+
+ MaxMask8 = MaskedVal1;
+ } else {
+ unsigned MaskBW = Op1.getValueType().getSizeInBits();
+ APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
+
+ KnownBits Known1;
+ if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
+ return true;
+
+ // Compute an upper bound on mask[7:0].
+ KnownBits MaskBits = Known1.extractBits(8, 0);
+ MaxMask8 = MaskBits.getMaxValue().getZExtValue();
+ }
+
+ // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
+ // getMaxValue are always zero and thus not demanded; likewise src bits.
+ APInt SrcDemanded = OriginalDemandedBits;
+ if (MaxMask8 < BitWidth) {
+ unsigned Cut = (unsigned)MaxMask8;
+ SrcDemanded.clearBits(Cut, BitWidth);
+ }
+
+
+ KnownBits KnownSrc;
+ if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
+ return true;
+
+
+ Known.One.clearAllBits();
+ Known.Zero.clearAllBits();
+ if (MaxMask8 < BitWidth) {
+ unsigned Cut = (unsigned)MaxMask8;
+ Known.Zero.setBits(Cut, BitWidth);
+ }
+
+ break;
+ }
+
case X86ISD::PDEP: {
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
diff --git a/llvm/test/CodeGen/X86/combine-bzhi.ll b/llvm/test/CodeGen/X86/combine-bzhi.ll
index 54e76469dbb82..281c513c72c0c 100644
--- a/llvm/test/CodeGen/X86/combine-bzhi.ll
+++ b/llvm/test/CodeGen/X86/combine-bzhi.ll
@@ -40,3 +40,50 @@ define i64 @test_bzhi64_constfold() nounwind readnone {
ret i64 %1
}
+define i32 @test_bzhi32_mask_const_highbits(i32 %a) nounwind {
+; CHECK-LABEL: test_bzhi32_mask_const_highbits:
+; CHECK-NOT: $257
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $1, %eax
+; CHECK-NEXT: bzhil
+; CHECK-NEXT: retq
+ %1 = tail call i32 @llvm.x86.bmi.bzhi.32(i32 %a, i32 257)
+ ret i32 %1
+}
+
+define i64 @test_bzhi64_mask_const_highbits(i64 %a) nounwind {
+; CHECK-LABEL: test_bzhi64_mask_const_highbits:
+; CHECK-NOT: $257
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $1, %eax
+; CHECK-NEXT: bzhiq
+; CHECK-NEXT: retq
+ %1 = tail call i64 @llvm.x86.bmi.bzhi.64(i64 %a, i64 257)
+ ret i64 %1
+}
+
+define i32 @test_bzhi32_rule2_mask_max31_kills_topbit(i32 %a, i32 %m) nounwind {
+; CHECK-LABEL: test_bzhi32_rule2_mask_max31_kills_topbit:
+; CHECK: # %bb.0:
+; CHECK-NEXT: andl $31, %esi
+; CHECK-NEXT: bzhil
+; CHECK-NEXT: retq
+ %mask = and i32 %m, 31
+ %hi = shl i32 %a, 31
+ %src = or i32 %a, %hi
+ %r = tail call i32 @llvm.x86.bmi.bzhi.32(i32 %src, i32 %mask)
+ ret i32 %r
+}
+
+define i64 @test_bzhi64_rule2_mask_max63_kills_topbit(i64 %a, i64 %m) nounwind {
+; CHECK-LABEL: test_bzhi64_rule2_mask_max63_kills_topbit:
+; CHECK: # %bb.0:
+; CHECK-NEXT: andl $63, %esi
+; CHECK-NEXT: bzhiq
+; CHECK-NEXT: retq
+ %mask = and i64 %m, 63
+ %hi = shl i64 %a, 63
+ %src = or i64 %a, %hi
+ %r = tail call i64 @llvm.x86.bmi.bzhi.64(i64 %src, i64 %mask)
+ ret i64 %r
+}
From 8d35e962a55b2047323f18eb2498cd2a9aec6a49 Mon Sep 17 00:00:00 2001
From: SainaDaneshmandjahromi <daneshmand.saina at gmail.com>
Date: Sun, 25 Jan 2026 14:19:46 -0700
Subject: [PATCH 2/2] [X86] clang-format
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 89 ++++++++++++-------------
1 file changed, 44 insertions(+), 45 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ca33b70f1a61b..7d7158615a3ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45440,62 +45440,61 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
break;
}
-case X86ISD::BZHI: {
- SDValue Op0 = Op.getOperand(0); // src
- SDValue Op1 = Op.getOperand(1); // mask
+ case X86ISD::BZHI: {
+ SDValue Op0 = Op.getOperand(0); // src
+ SDValue Op1 = Op.getOperand(1); // mask
- // Rule 1: Only the bottom 8 bits of the mask are required.
- // Track an upper bound on mask[7:0] so we can apply Rule 2.
- uint64_t MaxMask8 = 255;
+ // Rule 1: Only the bottom 8 bits of the mask are required.
+ // Track an upper bound on mask[7:0] so we can apply Rule 2.
+ uint64_t MaxMask8 = 255;
- if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
- // NOTE: SimplifyDemandedBits won't do this for constants.
- uint64_t Val1 = Cst1->getZExtValue();
- uint64_t MaskedVal1 = Val1 & 0xFF;
+ if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
+ // NOTE: SimplifyDemandedBits won't do this for constants.
+ uint64_t Val1 = Cst1->getZExtValue();
+ uint64_t MaskedVal1 = Val1 & 0xFF;
- if (MaskedVal1 != Val1) {
- SDLoc DL(Op);
- EVT MaskVT = Op1.getValueType();
- SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
- return TLO.CombineTo(Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
- }
+ if (MaskedVal1 != Val1) {
+ SDLoc DL(Op);
+ EVT MaskVT = Op1.getValueType();
+ SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
+ return TLO.CombineTo(
+ Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
+ }
- MaxMask8 = MaskedVal1;
- } else {
- unsigned MaskBW = Op1.getValueType().getSizeInBits();
- APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
+ MaxMask8 = MaskedVal1;
+ } else {
+ unsigned MaskBW = Op1.getValueType().getSizeInBits();
+ APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
- KnownBits Known1;
- if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
- return true;
+ KnownBits Known1;
+ if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
+ return true;
- // Compute an upper bound on mask[7:0].
- KnownBits MaskBits = Known1.extractBits(8, 0);
- MaxMask8 = MaskBits.getMaxValue().getZExtValue();
- }
+ // Compute an upper bound on mask[7:0].
+ KnownBits MaskBits = Known1.extractBits(8, 0);
+ MaxMask8 = MaskBits.getMaxValue().getZExtValue();
+ }
- // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
- // getMaxValue are always zero and thus not demanded; likewise src bits.
- APInt SrcDemanded = OriginalDemandedBits;
- if (MaxMask8 < BitWidth) {
- unsigned Cut = (unsigned)MaxMask8;
- SrcDemanded.clearBits(Cut, BitWidth);
- }
+ // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
+ // getMaxValue are always zero and thus not demanded; likewise src bits.
+ APInt SrcDemanded = OriginalDemandedBits;
+ if (MaxMask8 < BitWidth) {
+ unsigned Cut = (unsigned)MaxMask8;
+ SrcDemanded.clearBits(Cut, BitWidth);
+ }
-
- KnownBits KnownSrc;
- if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
- return true;
+ KnownBits KnownSrc;
+ if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
+ return true;
-
- Known.One.clearAllBits();
- Known.Zero.clearAllBits();
- if (MaxMask8 < BitWidth) {
- unsigned Cut = (unsigned)MaxMask8;
- Known.Zero.setBits(Cut, BitWidth);
+ Known.One.clearAllBits();
+ Known.Zero.clearAllBits();
+ if (MaxMask8 < BitWidth) {
+ unsigned Cut = (unsigned)MaxMask8;
+ Known.Zero.setBits(Cut, BitWidth);
}
- break;
+ break;
}
case X86ISD::PDEP: {
More information about the llvm-commits
mailing list