[llvm] [X86] SimplifyDemandedBitsForTargetNode: add X86ISD::BZHI handling (PR #177881)

Saina Daneshmand via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 25 13:34:01 PST 2026


https://github.com/SainaDaneshmandjahromi created https://github.com/llvm/llvm-project/pull/177881

This patch adds SimplifyDemandedBitsForTargetNode support for X86ISD::BZHI.

For BZHI, only the low 8 bits of the second (index) operand, referred to here as the mask, are semantically relevant.
This change teaches SimplifyDemandedBitsForTargetNode to:

* Limit demanded bits of the mask operand to mask[7:0]
* Use KnownBits information to compute an upper bound on mask[7:0]
* When the maximum possible value of mask[7:0] is less than the operand
  bitwidth, treat result bits at and above that bound as always zero (BZHI
  clears every result bit at or above the index) and mark the corresponding
  source bits as not demanded

Tests are added to combine-bzhi.ll.

Closes #177369


>From 2d9109491c863c1c0554ebedf3b99155cc926f0a Mon Sep 17 00:00:00 2001
From: SainaDaneshmandjahromi <daneshmand.saina at gmail.com>
Date: Sun, 25 Jan 2026 14:18:55 -0700
Subject: [PATCH 1/2] [X86] Improve BZHI SimplifyDemandedBits

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 58 +++++++++++++++++++++++++
 llvm/test/CodeGen/X86/combine-bzhi.ll   | 47 ++++++++++++++++++++
 2 files changed, 105 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e54f4ed2fb26c..ca33b70f1a61b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45440,6 +45440,64 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
 
     break;
   }
+case X86ISD::BZHI: {
+  SDValue Op0 = Op.getOperand(0); // src
+  SDValue Op1 = Op.getOperand(1); // mask
+
+  // Rule 1: Only the bottom 8 bits of the mask are required.
+  // Track an upper bound on mask[7:0] so we can apply Rule 2.
+  uint64_t MaxMask8 = 255;
+
+  if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
+    // NOTE: SimplifyDemandedBits won't do this for constants.
+    uint64_t Val1 = Cst1->getZExtValue();
+    uint64_t MaskedVal1 = Val1 & 0xFF;
+
+    if (MaskedVal1 != Val1) {
+      SDLoc DL(Op);
+      EVT MaskVT = Op1.getValueType();
+      SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
+      return TLO.CombineTo(Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
+    }
+
+    MaxMask8 = MaskedVal1;
+  } else {
+    unsigned MaskBW = Op1.getValueType().getSizeInBits();
+    APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
+
+    KnownBits Known1;
+    if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
+      return true;
+
+    // Compute an upper bound on mask[7:0].
+    KnownBits MaskBits = Known1.extractBits(8, 0);
+    MaxMask8 = MaskBits.getMaxValue().getZExtValue();
+  }
+
+  // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
+  // getMaxValue are always zero and thus not demanded; likewise src bits.
+  APInt SrcDemanded = OriginalDemandedBits;
+  if (MaxMask8 < BitWidth) {
+    unsigned Cut = (unsigned)MaxMask8;
+    SrcDemanded.clearBits(Cut, BitWidth);
+  }
+
+  
+  KnownBits KnownSrc;
+  if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
+    return true;
+
+  
+  Known.One.clearAllBits();
+  Known.Zero.clearAllBits();
+  if (MaxMask8 < BitWidth) {
+    unsigned Cut = (unsigned)MaxMask8;
+    Known.Zero.setBits(Cut, BitWidth);
+    }
+
+  break;
+  }
+
   case X86ISD::PDEP: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
diff --git a/llvm/test/CodeGen/X86/combine-bzhi.ll b/llvm/test/CodeGen/X86/combine-bzhi.ll
index 54e76469dbb82..281c513c72c0c 100644
--- a/llvm/test/CodeGen/X86/combine-bzhi.ll
+++ b/llvm/test/CodeGen/X86/combine-bzhi.ll
@@ -40,3 +40,50 @@ define i64 @test_bzhi64_constfold() nounwind readnone {
   ret i64 %1
 }
 
+define i32 @test_bzhi32_mask_const_highbits(i32 %a) nounwind {
+; CHECK-LABEL: test_bzhi32_mask_const_highbits:
+; CHECK-NOT:   $257
+; CHECK:       # %bb.0:
+; CHECK-NEXT:  movl $1, %eax
+; CHECK-NEXT:  bzhil
+; CHECK-NEXT:  retq
+  %1 = tail call i32 @llvm.x86.bmi.bzhi.32(i32 %a, i32 257)
+  ret i32 %1
+}
+
+define i64 @test_bzhi64_mask_const_highbits(i64 %a) nounwind {
+; CHECK-LABEL: test_bzhi64_mask_const_highbits:
+; CHECK-NOT:   $257
+; CHECK:       # %bb.0:
+; CHECK-NEXT:  movl $1, %eax
+; CHECK-NEXT:  bzhiq
+; CHECK-NEXT:  retq
+  %1 = tail call i64 @llvm.x86.bmi.bzhi.64(i64 %a, i64 257)
+  ret i64 %1
+}
+
+define i32 @test_bzhi32_rule2_mask_max31_kills_topbit(i32 %a, i32 %m) nounwind {
+; CHECK-LABEL: test_bzhi32_rule2_mask_max31_kills_topbit:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:  andl $31, %esi
+; CHECK-NEXT:  bzhil
+; CHECK-NEXT:  retq
+  %mask = and i32 %m, 31
+  %hi  = shl i32 %a, 31
+  %src = or i32 %a, %hi
+  %r = tail call i32 @llvm.x86.bmi.bzhi.32(i32 %src, i32 %mask)
+  ret i32 %r
+}
+
+define i64 @test_bzhi64_rule2_mask_max63_kills_topbit(i64 %a, i64 %m) nounwind {
+; CHECK-LABEL: test_bzhi64_rule2_mask_max63_kills_topbit:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:  andl $63, %esi
+; CHECK-NEXT:  bzhiq
+; CHECK-NEXT:  retq
+  %mask = and i64 %m, 63
+  %hi  = shl i64 %a, 63
+  %src = or i64 %a, %hi
+  %r = tail call i64 @llvm.x86.bmi.bzhi.64(i64 %src, i64 %mask)
+  ret i64 %r
+}

>From 8d35e962a55b2047323f18eb2498cd2a9aec6a49 Mon Sep 17 00:00:00 2001
From: SainaDaneshmandjahromi <daneshmand.saina at gmail.com>
Date: Sun, 25 Jan 2026 14:19:46 -0700
Subject: [PATCH 2/2] [X86] clang-format

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 89 ++++++++++++-------------
 1 file changed, 44 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ca33b70f1a61b..7d7158615a3ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45440,62 +45440,61 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
 
     break;
   }
-case X86ISD::BZHI: {
-  SDValue Op0 = Op.getOperand(0); // src
-  SDValue Op1 = Op.getOperand(1); // mask
+  case X86ISD::BZHI: {
+    SDValue Op0 = Op.getOperand(0); // src
+    SDValue Op1 = Op.getOperand(1); // mask
 
-  // Rule 1: Only the bottom 8 bits of the mask are required.
-  // Track an upper bound on mask[7:0] so we can apply Rule 2.
-  uint64_t MaxMask8 = 255;
+    // Rule 1: Only the bottom 8 bits of the mask are required.
+    // Track an upper bound on mask[7:0] so we can apply Rule 2.
+    uint64_t MaxMask8 = 255;
 
-  if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
-    // NOTE: SimplifyDemandedBits won't do this for constants.
-    uint64_t Val1 = Cst1->getZExtValue();
-    uint64_t MaskedVal1 = Val1 & 0xFF;
+    if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
+      // NOTE: SimplifyDemandedBits won't do this for constants.
+      uint64_t Val1 = Cst1->getZExtValue();
+      uint64_t MaskedVal1 = Val1 & 0xFF;
 
-    if (MaskedVal1 != Val1) {
-      SDLoc DL(Op);
-      EVT MaskVT = Op1.getValueType();
-      SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
-      return TLO.CombineTo(Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
-    }
+      if (MaskedVal1 != Val1) {
+        SDLoc DL(Op);
+        EVT MaskVT = Op1.getValueType();
+        SDValue NewMask = TLO.DAG.getConstant(MaskedVal1, DL, MaskVT);
+        return TLO.CombineTo(
+            Op, TLO.DAG.getNode(X86ISD::BZHI, DL, VT, Op0, NewMask));
+      }
 
-    MaxMask8 = MaskedVal1;
-  } else {
-    unsigned MaskBW = Op1.getValueType().getSizeInBits();
-    APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
+      MaxMask8 = MaskedVal1;
+    } else {
+      unsigned MaskBW = Op1.getValueType().getSizeInBits();
+      APInt MaskDemand = APInt::getLowBitsSet(MaskBW, 8);
 
-    KnownBits Known1;
-    if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
-      return true;
+      KnownBits Known1;
+      if (SimplifyDemandedBits(Op1, MaskDemand, Known1, TLO, Depth + 1))
+        return true;
 
-    // Compute an upper bound on mask[7:0].
-    KnownBits MaskBits = Known1.extractBits(8, 0);
-    MaxMask8 = MaskBits.getMaxValue().getZExtValue();
-  }
+      // Compute an upper bound on mask[7:0].
+      KnownBits MaskBits = Known1.extractBits(8, 0);
+      MaxMask8 = MaskBits.getMaxValue().getZExtValue();
+    }
 
-  // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
-  // getMaxValue are always zero and thus not demanded; likewise src bits.
-  APInt SrcDemanded = OriginalDemandedBits;
-  if (MaxMask8 < BitWidth) {
-    unsigned Cut = (unsigned)MaxMask8;
-    SrcDemanded.clearBits(Cut, BitWidth);
-  }
+    // Rule 2: If mask[7:0] is known to be < BitWidth, then bits at/above
+    // getMaxValue are always zero and thus not demanded; likewise src bits.
+    APInt SrcDemanded = OriginalDemandedBits;
+    if (MaxMask8 < BitWidth) {
+      unsigned Cut = (unsigned)MaxMask8;
+      SrcDemanded.clearBits(Cut, BitWidth);
+    }
 
-  
-  KnownBits KnownSrc;
-  if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
-    return true;
+    KnownBits KnownSrc;
+    if (SimplifyDemandedBits(Op0, SrcDemanded, KnownSrc, TLO, Depth + 1))
+      return true;
 
-  
-  Known.One.clearAllBits();
-  Known.Zero.clearAllBits();
-  if (MaxMask8 < BitWidth) {
-    unsigned Cut = (unsigned)MaxMask8;
-    Known.Zero.setBits(Cut, BitWidth);
+    Known.One.clearAllBits();
+    Known.Zero.clearAllBits();
+    if (MaxMask8 < BitWidth) {
+      unsigned Cut = (unsigned)MaxMask8;
+      Known.Zero.setBits(Cut, BitWidth);
     }
 
-  break;
+    break;
   }
 
   case X86ISD::PDEP: {



More information about the llvm-commits mailing list