[flang-commits] [flang] [llvm] [GlobalISel] Port over `simplifyDemandedBits` to GlobalISel (PR #198808)

Alan Li via flang-commits flang-commits at lists.llvm.org
Thu Jun 11 14:51:00 PDT 2026


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/198808

>From 1283843095854a22e7eec868cc081aa3d9cbad4f Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 26 May 2026 08:09:32 -0700
Subject: [PATCH 1/4] [GlobalISel] Port SimplifyDemandedBits with a fused,
 idiomatic simplifier

Introduce demanded-bits analysis and simplification for GlobalISel,
mirroring SelectionDAG's SimplifyDemandedBits in behaviour while staying
GISel-idiomatic.

Core query struct:
* GISelDemandedMask bundles a demanded-bits mask and a demanded-elements
  mask. It has no public positional constructor: the named factories
  (getAllBits, forBits, forElts) and the withBits/withElts/forScalarElement
  builders make a (Bits, Elts) swap a compile error, and forBits asserts
  that the bit mask matches the register's scalar width.

Value-tracking (GISelValueTracking):
* computeKnownBitsImpl now takes a GISelDemandedMask; the public getKnownBits
  gains a struct overload, and the legacy elements-only getKnownBits /
  computeKnownBitsImpl forms are kept as public forwarders. AMDGPU target
  call sites migrate to the public getKnownBits API. maskedValueIsZero
  threads its check mask as the demanded-bits component.
* Shift opcodes (G_SHL/G_LSHR/G_ASHR) narrow the demanded mask of their
  shifted source over the feasible shift-amount range, with an all-ones
  fallback when the range is unprovable.
* getDemandedBitsForUse: per-opcode reverse transfer (G_TRUNC, G_AND,
  G_OR, G_XOR, G_ZEXT, G_SEXT) returning which bits of an operand are
  demanded. It over-approximates (all-ones fallback), so it can never
  under-demand.
* computeDemandedBits: the union, over all users of a register, of each
  user's demand, recursing into each user's own users (depth-capped).

Combiner (CombinerHelper):
* getDemandedBitsSimplifiedReg: pure decision returning the operand a
  redundant G_AND / disjoint G_OR can be replaced by under a mask.
* trySimplifyDemandedBits(MI, OperandNo, Mask): mutator wrapper used by
  the trunc-of-shift path.
* matchSimplifyDemandedBits + the post-legalize simplify_demanded_bits
  combine: per-instruction driver that computes a value's demand from its
  users and eliminates the masked op when redundant. Deadloop-safe via
  the build_fn pattern (decide in match, run in apply). Demand is the
  union over all users, so replacing every use is sound and multi-use
  needs no separate look-through.
* narrow_trunc_shr_const and the trunc_shift demand hook for trunc of
  shifts.

Tests: GISelValueTracking unit tests for the reverse transfer and union,
an exhaustive i8 soundness oracle for the over-approximation, and AArch64
post-legalizer MIR tests including a multi-use soundness gate.
---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  31 ++
 .../CodeGen/GlobalISel/GISelDemandedMask.h    |  75 +++
 .../CodeGen/GlobalISel/GISelValueTracking.h   |  43 +-
 .../include/llvm/Target/GlobalISel/Combine.td |  34 +-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 161 ++++++
 .../CodeGen/GlobalISel/GISelValueTracking.cpp | 417 +++++++++-----
 llvm/lib/Target/AArch64/AArch64Combine.td     |   3 +-
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp     |  17 +-
 .../GlobalISel/combine-narrow-trunc-shr.mir   | 193 +++++++
 .../combine-simplify-demanded-bits.mir        |  54 ++
 llvm/test/CodeGen/AArch64/arm64-vhadd.ll      |   6 +-
 .../CodeGen/AArch64/hadd-combine-scalar.ll    |  21 +-
 llvm/test/CodeGen/AArch64/hadd-combine.ll     |   6 +-
 .../CodeGen/AMDGPU/GlobalISel/sext_inreg.ll   |  24 +-
 .../CodeGen/GlobalISel/KnownBitsTest.cpp      | 527 +++++++++++++++++-
 15 files changed, 1395 insertions(+), 217 deletions(-)
 create mode 100644 llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
 create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
 create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
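
The shift bullet in the commit message narrows the demand on a shifted source over the feasible shift-amount range. Below is a minimal standalone sketch of that transfer. It assumes plain `uint64_t` masks (widths up to 64) and an inclusive `[Lo, Hi]` amount range in place of the patch's `APInt` and `ConstantRange`. The names `shiftSrcDemand`, `eval8` and `isSound8` are hypothetical. The exhaustive 8-bit check follows the spirit of the patch's soundness oracle but is not a copy of it:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the three shift opcodes handled by the patch.
enum class ShiftOp { Shl, LShr, AShr };

static uint64_t lowOnes(unsigned Width) {
  return Width >= 64 ? ~0ull : (1ull << Width) - 1;
}

// Source bits of a Width-bit shift that can affect the demanded result bits
// for any shift amount in [Lo, Hi]. Falls back to all-ones when the range is
// empty or reaches Width, mirroring the conservative fallback in the patch.
uint64_t shiftSrcDemand(ShiftOp Op, uint64_t Demand, unsigned Width,
                        unsigned Lo, unsigned Hi) {
  const uint64_t All = lowOnes(Width);
  if (Lo > Hi || Hi >= Width)
    return All;
  Demand &= All;
  uint64_t Src = 0;
  for (unsigned C = Lo; C <= Hi; ++C) {
    if (Op == ShiftOp::Shl)
      Src |= Demand >> C;         // result bit p comes from source bit p - C
    else
      Src |= (Demand << C) & All; // result bit p comes from source bit p + C
  }
  // ASHR: result bits at positions >= Width - Hi may be sign-fill copies.
  if (Op == ShiftOp::AShr && Hi > 0 && (Demand >> (Width - Hi)) != 0)
    Src |= 1ull << (Width - 1);
  return Src;
}

// 8-bit reference semantics for the exhaustive check (requires C < 8).
static unsigned eval8(ShiftOp Op, unsigned X, unsigned C) {
  if (Op == ShiftOp::Shl)
    return (X << C) & 0xFF;
  unsigned R = X >> C;
  if (Op == ShiftOp::AShr && (X & 0x80))
    R |= (0xFFu << (8 - C)) & 0xFF; // replicate the sign bit into the top C bits
  return R;
}

// Soundness oracle (requires Hi < 8): inputs that agree on every returned
// source bit must give results that agree on every demanded bit.
bool isSound8(ShiftOp Op, unsigned Demand, unsigned Lo, unsigned Hi) {
  const uint64_t Src = shiftSrcDemand(Op, Demand, 8, Lo, Hi);
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      if ((X ^ Y) & Src)
        continue;
      for (unsigned C = Lo; C <= Hi; ++C)
        if ((eval8(Op, X, C) ^ eval8(Op, Y, C)) & Demand)
          return false;
    }
  return true;
}
```

For example, demanding only the high nibble of an 8-bit `shl X, 4` demands only the low nibble of `X`. Demanding bit 7 of `ashr X, 4` demands only the sign bit.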

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index aa61310994a67..fa43650fd202b 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -39,6 +39,7 @@ class MachineRegisterInfo;
 class MachineInstr;
 class MachineOperand;
 class GISelValueTracking;
+struct GISelDemandedMask;
 class MachineDominatorTree;
 class LegalizerInfo;
 struct LegalityQuery;
@@ -1177,6 +1178,36 @@ class CombinerHelper {
   LLVM_ABI bool matchAVG(MachineInstr &MI, MachineRegisterInfo &MRI, Register X,
                          Register Y, unsigned TargetOpc) const;
 
+  /// If the value defined by operand \p OperandNo of \p MI can be replaced by
+  /// one of its existing operand registers when only \p Mask is demanded,
+  /// return that register. Pure (no mutation), generic opcodes only (no target
+  /// dispatch). Currently handles redundant G_AND / disjoint G_OR by a
+  /// constant.
+  std::optional<Register>
+  getDemandedBitsSimplifiedReg(MachineInstr &MI, unsigned OperandNo,
+                               const GISelDemandedMask &Mask) const;
+
+  /// Demanded-bits back-propagation driver. Attempts to rewrite the value
+  /// defined by operand \p OperandNo of \p MI under the assumption that only
+  /// the bits set in \p Mask.Bits and the elements set in \p Mask.Elts of that
+  /// def are consumed. \p OperandNo selects which def is described, so the
+  /// driver can be used on instructions with multiple defs. Every use of that
+  /// def is rewritten, so \p Mask must cover the demand of all its users.
+  /// Returns true if the def was rerouted to an operand and \p MI erased.
+  bool trySimplifyDemandedBits(MachineInstr &MI, unsigned OperandNo,
+                               const GISelDemandedMask &Mask) const;
+
+  /// Per-instruction demanded-bits simplification. Computes the instruction's
+  /// demand from its users; if its def can be replaced under that demand,
+  /// captures the rewrite. Deadloop-safe: returns true only when apply will
+  /// make progress.
+  bool matchSimplifyDemandedBits(MachineInstr &MI, BuildFnTy &MatchInfo) const;
+
+  // Match (G_TRUNC (G_LSHR/G_ASHR X, K-const)) when X's bits beyond DstBW
+  // (zero for LSHR, sign-bit replicated for ASHR) are provably idle. Rewrites
+  // to (G_LSHR/G_ASHR (G_TRUNC X), K-trunc), eliminating the outer trunc.
+  bool matchNarrowTruncShrConst(MachineInstr &MI, BuildFnTy &MatchInfo) const;
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
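
The decision documented for getDemandedBitsSimplifiedReg reduces to two mask conditions. `(and X, C)` can be replaced by `X` when the demanded bits are a subset of `C`. `(or X, C)` can be replaced by `X` when `C` shares no bits with them. Below is a hedged 8-bit sketch of those conditions, together with a brute-force check that compares them against direct evaluation. The names `LogicOp`, `canReplaceWithLHS` and `agreesOnDemandedBits` are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit model: Op is the opcode of `Dst = Op X, C` with a
// constant C, and Demand is the set of Dst bits any user reads.
enum class LogicOp { And, Or };

// Closed-form test: can X stand in for Dst on every demanded bit?
bool canReplaceWithLHS(LogicOp Op, uint8_t C, uint8_t Demand) {
  if (Op == LogicOp::And)
    return (Demand & ~C) == 0; // every demanded bit is kept by the mask
  return (C & Demand) == 0;    // the OR only sets bits nobody reads
}

// Brute force over all X: does Dst agree with X on every demanded bit?
bool agreesOnDemandedBits(LogicOp Op, uint8_t C, uint8_t Demand) {
  for (unsigned X = 0; X < 256; ++X) {
    unsigned Dst = Op == LogicOp::And ? (X & C) : (X | C);
    if ((Dst ^ X) & Demand)
      return false;
  }
  return true;
}
```

The two functions agree for every `C` and every demand mask. For a single constant operand, the closed-form test is therefore exact: it is neither a superset nor a subset of the legal rewrites.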
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h b/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
new file mode 100644
index 0000000000000..a2d1d8216355e
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
@@ -0,0 +1,75 @@
+//===- llvm/CodeGen/GlobalISel/GISelDemandedMask.h --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// Defines the GISelDemandedMask query struct used by GISelValueTracking and
+/// CombinerHelper to express demanded-bit and demanded-element masks as a
+/// single, ergonomic value with named factories.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_GLOBALISEL_GISELDEMANDEDMASK_H
+#define LLVM_CODEGEN_GLOBALISEL_GISELDEMANDEDMASK_H
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/CodeGen/Register.h"
+#include "llvm/Support/Compiler.h"
+
+#include <utility>
+
+namespace llvm {
+
+class MachineRegisterInfo;
+
+/// Bundles a demanded-bits mask and a demanded-elements mask used to
+/// constrain GISel value-tracking queries. The struct has no public
+/// positional constructor; callers must use named factories so that
+/// `(Bits, Elts)` cannot be swapped at a call site.
+struct GISelDemandedMask {
+  APInt Bits;
+  APInt Elts;
+
+  GISelDemandedMask() = delete;
+
+  /// Construct a mask demanding every bit and every element of \p R.
+  LLVM_ABI static GISelDemandedMask getAllBits(const MachineRegisterInfo &MRI,
+                                               Register R);
+
+  /// Construct a mask demanding the given bit mask of every element of \p R.
+  LLVM_ABI static GISelDemandedMask forBits(const MachineRegisterInfo &MRI,
+                                            Register R, APInt Bits);
+
+  /// Construct a mask demanding every bit of the given elements of \p R.
+  LLVM_ABI static GISelDemandedMask forElts(const MachineRegisterInfo &MRI,
+                                            Register R, APInt Elts);
+
+  /// Return a copy with the bit mask replaced.
+  GISelDemandedMask withBits(APInt NewBits) const {
+    GISelDemandedMask M = *this;
+    M.Bits = std::move(NewBits);
+    return M;
+  }
+
+  /// Return a copy with the element mask replaced.
+  GISelDemandedMask withElts(APInt NewElts) const {
+    GISelDemandedMask M = *this;
+    M.Elts = std::move(NewElts);
+    return M;
+  }
+
+  /// Return a copy demanding a single scalar element. Used when recursing
+  /// from a vector opcode into a per-lane producer.
+  GISelDemandedMask forScalarElement() const { return withElts(APInt(1, 1)); }
+
+private:
+  GISelDemandedMask(APInt B, APInt E)
+      : Bits(std::move(B)), Elts(std::move(E)) {}
+};
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_GLOBALISEL_GISELDEMANDEDMASK_H
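
Below is a self-contained miniature of the named-factory pattern this header uses. It uses `uint64_t` masks and replaces the register with an explicit scalar width and lane count; both are assumptions of this sketch, and `DemandedMaskSketch` is hypothetical. The two-argument constructor is private, so a call such as `DemandedMaskSketch(Elts, Bits)` outside the class fails to compile. That is the swap protection the header comment describes:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature of GISelDemandedMask: Bits is a per-lane bit mask,
// Elts a lane mask (1 for scalars). Widths are limited to 64.
struct DemandedMaskSketch {
  uint64_t Bits;
  uint64_t Elts;

  DemandedMaskSketch() = delete;

  // Every bit of every lane.
  static DemandedMaskSketch getAllBits(unsigned ScalarBW, unsigned NumElts) {
    return DemandedMaskSketch(ones(ScalarBW), ones(NumElts));
  }

  // The given bits of every lane; rejects masks wider than the scalar type.
  static DemandedMaskSketch forBits(unsigned ScalarBW, unsigned NumElts,
                                    uint64_t B) {
    assert((B & ~ones(ScalarBW)) == 0 && "Bits wider than the scalar type");
    return getAllBits(ScalarBW, NumElts).withBits(B);
  }

  // Every bit of the given lanes.
  static DemandedMaskSketch forElts(unsigned ScalarBW, unsigned NumElts,
                                    uint64_t E) {
    return getAllBits(ScalarBW, NumElts).withElts(E);
  }

  DemandedMaskSketch withBits(uint64_t B) const {
    DemandedMaskSketch M = *this;
    M.Bits = B;
    return M;
  }

  DemandedMaskSketch withElts(uint64_t E) const {
    DemandedMaskSketch M = *this;
    M.Elts = E;
    return M;
  }

  // A single scalar lane, as when recursing into a per-lane producer.
  DemandedMaskSketch forScalarElement() const { return withElts(1); }

private:
  DemandedMaskSketch(uint64_t B, uint64_t E) : Bits(B), Elts(E) {}
  static uint64_t ones(unsigned N) { return N >= 64 ? ~0ull : (1ull << N) - 1; }
};
```

The `with*` builders return modified copies, so a mask can be narrowed step by step without exposing the positional constructor.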
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h b/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
index 722a45eeb07e5..85819bd32101c 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
+#include "llvm/CodeGen/GlobalISel/GISelDemandedMask.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/IR/InstrTypes.h"
@@ -39,11 +40,14 @@ class LLVM_ABI GISelValueTracking : public GISelChangeObserver {
   unsigned MaxDepth;
 
   void computeKnownBitsMin(Register Src0, Register Src1, KnownBits &Known,
-                           const APInt &DemandedElts, unsigned Depth = 0);
+                           const GISelDemandedMask &Mask, unsigned Depth = 0);
 
   unsigned computeNumSignBitsMin(Register Src0, Register Src1,
                                  const APInt &DemandedElts, unsigned Depth = 0);
 
+  void computeKnownBitsImpl(Register R, KnownBits &Known,
+                            const GISelDemandedMask &Mask, unsigned Depth = 0);
+
   void computeKnownFPClass(Register R, KnownFPClass &Known,
                            FPClassTest InterestedClasses, unsigned Depth);
 
@@ -64,28 +68,37 @@ class LLVM_ABI GISelValueTracking : public GISelChangeObserver {
 
   const DataLayout &getDataLayout() const { return DL; }
 
-  void computeKnownBitsImpl(Register R, KnownBits &Known,
-                            const APInt &DemandedElts, unsigned Depth = 0);
-
   unsigned computeNumSignBits(Register R, const APInt &DemandedElts,
                               unsigned Depth = 0);
   unsigned computeNumSignBits(Register R, unsigned Depth = 0);
 
   // KnownBitsAPI
   KnownBits getKnownBits(Register R);
+  KnownBits getKnownBits(Register R, const GISelDemandedMask &Mask,
+                         unsigned Depth = 0);
+
+  /// Convenience wrapper for callers that have only a demanded-elements mask
+  /// in hand. Forwards to the GISelDemandedMask form with all bits demanded.
   KnownBits getKnownBits(Register R, const APInt &DemandedElts,
                          unsigned Depth = 0);
 
+  /// Compute known bits for \p R restricted to \p DemandedElts. Forwards to the
+  /// GISelDemandedMask form (all bits demanded); kept for target/out-of-tree
+  /// callers that pass an elements mask directly.
+  void computeKnownBitsImpl(Register R, KnownBits &Known,
+                            const APInt &DemandedElts, unsigned Depth = 0);
+
   // Calls getKnownBits for first operand def of MI.
   KnownBits getKnownBits(MachineInstr &MI);
   APInt getKnownZeroes(Register R);
   APInt getKnownOnes(Register R);
 
-  /// \return true if 'V & Mask' is known to be zero in DemandedElts. We use
-  /// this predicate to simplify operations downstream.
-  /// Mask is known to be zero for bits that V cannot have.
+  /// \return true if 'V & Mask' is known to be zero. We use this predicate
+  /// to simplify operations downstream. Mask is known to be zero for bits
+  /// that V cannot have.
   bool maskedValueIsZero(Register Val, const APInt &Mask) {
-    return Mask.isSubsetOf(getKnownBits(Val).Zero);
+    return Mask.isSubsetOf(
+        getKnownBits(Val, GISelDemandedMask::forBits(MRI, Val, Mask)).Zero);
   }
 
   /// \return true if the sign bit of Op is known to be zero.  We use this
@@ -97,6 +110,20 @@ class LLVM_ABI GISelValueTracking : public GISelChangeObserver {
     Known.Zero.setLowBits(Log2(Alignment));
   }
 
+  /// Bits of operand \p OpIdx of \p UseMI that \p UseMI demands, given that
+  /// only \p DemandOfUser bits of UseMI's def are demanded. The returned mask
+  /// is at the operand's scalar bit width. Over-approximates: returns all-ones
+  /// for opcodes/operands whose reverse transfer is not modeled.
+  APInt getDemandedBitsForUse(const MachineInstr &UseMI, unsigned OpIdx,
+                              const APInt &DemandOfUser);
+
+  /// Union, over all non-debug users of \p R, of the bits each user demands of
+  /// \p R (via getDemandedBitsForUse, recursing upward for each user's own
+  /// demand). Returns all-ones at/past the recursion cap, for vector/invalid
+  /// types, or when a user has multiple defs. Returns an empty (zero) mask when
+  /// \p R has no users (a dead value demands nothing).
+  APInt computeDemandedBits(Register R, unsigned Depth = 0);
+
   /// \return The known alignment for the pointer-like value \p R.
   Align computeKnownAlignment(Register R, unsigned Depth = 0);
 
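
For the width-changing opcodes, the reverse transfer documented for getDemandedBitsForUse is plain mask surgery, and computeDemandedBits then takes the union over users. Below is a minimal sketch on `uint64_t` masks, assuming operand widths from 1 to 64. The function names are hypothetical, and only G_ZEXT, G_SEXT and the union are modelled; the sketch includes an exhaustive i4 to i8 check of the G_SEXT rule:

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

static uint64_t ones(unsigned N) { return N >= 64 ? ~0ull : (1ull << N) - 1; }

// G_ZEXT: the low OpBW def bits are the operand; the rest are constant zero.
uint64_t zextOperandDemand(uint64_t DefDemand, unsigned OpBW) {
  return DefDemand & ones(OpBW);
}

// G_SEXT: the low OpBW def bits are the operand, and every def bit at or
// above OpBW - 1 is a copy of the operand's sign bit.
uint64_t sextOperandDemand(uint64_t DefDemand, unsigned OpBW) {
  uint64_t Op = DefDemand & ones(OpBW);
  if ((DefDemand >> (OpBW - 1)) != 0)
    Op |= 1ull << (OpBW - 1);
  return Op;
}

// Union over all users; no users means the value is dead and demands nothing.
uint64_t unionOfUserDemands(std::initializer_list<uint64_t> PerUser) {
  uint64_t U = 0;
  for (uint64_t D : PerUser)
    U |= D;
  return U;
}

// Exhaustive check of the G_SEXT rule for i4 -> i8: operands that agree on
// the returned bits must extend to values that agree on the demanded bits.
static unsigned sext4to8(unsigned X) { return (X & 0x8) ? (X | 0xF0) : X; }

bool sextRuleIsSound(unsigned DefDemand) {
  const uint64_t OpDemand = sextOperandDemand(DefDemand, 4);
  for (unsigned X = 0; X < 16; ++X)
    for (unsigned Y = 0; Y < 16; ++Y)
      if (!((X ^ Y) & OpDemand) && ((sext4to8(X) ^ sext4to8(Y)) & DefDemand))
        return false;
  return true;
}
```

Demanding only bit 7 of an i4 to i8 sign extension, for instance, demands only bit 3 (the sign bit) of the i4 operand.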
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index f3474eb95c436..a090eeecde14b 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1102,6 +1102,38 @@ def trunc_shift: GICombineRule <
   (apply [{ Helper.applyCombineTruncOfShift(*${root}, ${matchinfo}); }])
 >;
 
+// Demanded-bits-driven narrowing of trunc-of-right-shift with a constant
+// shift amount. Fires when X's bits beyond the destination width are
+// provably idle (known-zero for LSHR, sign-bit-replicated for ASHR), in
+// which case the outer truncate can be dropped:
+//
+//   (trunc (lshr X, K))  -> (lshr (trunc X), K)  when X[DstBW..K+DstBW) == 0
+//   (trunc (ashr X, K))  -> (ashr (trunc X), K)  when X has at least
+//                                                SrcBW-DstBW+1 sign bits.
+//
+// K must be a constant less than DstBW so the rewrite never recurses into
+// a re-truncated outer pattern -- guarantees fixpoint termination and
+// avoids the deadloop that a generic trunc-of-shift demand combine has
+// against existing rules like ``trunc_shift``.
+def narrow_trunc_shr_const : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_TRUNC):$root,
+    [{ return Helper.matchNarrowTruncShrConst(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
+// Demand-driven elimination of redundant G_AND / disjoint G_OR by a constant.
+// Computes the union of bit-demands from all users of the def; if the constant
+// mask (for AND) covers every demanded bit, or the constant or-bits (for OR)
+// are entirely undemanded, the instruction is replaced by its LHS operand.
+// Deadloop-safe: match returns true only when the rewrite will make progress.
+// NOTE: intentionally left unwired from all combiner groups -- wire-up is a
+// separate step once the rule is validated end-to-end.
+def simplify_demanded_bits : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_AND, G_OR):$root,
+         [{ return Helper.matchSimplifyDemandedBits(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
 // Transform (mul x, -1) -> (sub 0, x)
 def mul_by_neg_one: GICombineRule <
   (defs root:$dst),
@@ -2506,7 +2538,7 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
     reassocs, ptr_add_immed_chain, cmp_combines,
     shl_ashr_to_sext_inreg, neg_and_one_to_sext_inreg, sext_inreg_of_load,
     width_reduction_combines, select_combines, select_zero_false, select_not,
-    known_bits_simplifications, trunc_shift,
+    known_bits_simplifications, trunc_shift, narrow_trunc_shr_const,
     not_cmp_fold, opt_brcond_by_inverting_cond,
     const_combines, xor_of_and_with_same_reg, ptr_add_with_zero,
     shift_immed_chain, shift_of_shifted_logic_chain, load_or_combine,
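
The two side conditions in the narrow_trunc_shr_const comment can be checked exhaustively for one width pair. Below is a hedged sketch with SrcBW = 16 and DstBW = 8, chosen for illustration only; all names are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Width pair chosen for this sketch only.
constexpr unsigned SrcBW = 16, DstBW = 8;

// Arithmetic right shifts with the sign fill written out explicitly.
static uint16_t ashr16(uint16_t X, unsigned K) {
  uint16_t R = X >> K;
  if (X & 0x8000)
    R |= static_cast<uint16_t>(0xFFFFu << (16 - K));
  return R;
}

static uint8_t ashr8(uint8_t X, unsigned K) {
  uint8_t R = X >> K;
  if (X & 0x80)
    R |= static_cast<uint8_t>(0xFFu << (8 - K));
  return R;
}

// Number of leading bits equal to the sign bit (always >= 1).
static unsigned numSignBits16(uint16_t X) {
  const unsigned Sign = X >> 15;
  unsigned N = 1;
  while (N < 16 && ((X >> (15 - N)) & 1u) == Sign)
    ++N;
  return N;
}

// (trunc (lshr X, K)) == (lshr (trunc X), K) when X[DstBW, DstBW+K) == 0.
bool lshrRuleHolds(unsigned K) {
  for (unsigned X = 0; X <= 0xFFFF; ++X) {
    if ((X >> DstBW) & ((1u << K) - 1))
      continue; // precondition fails, so the combine would not fire
    uint8_t Wide = static_cast<uint8_t>(X >> K);
    uint8_t Narrow = static_cast<uint8_t>(static_cast<uint8_t>(X) >> K);
    if (Wide != Narrow)
      return false;
  }
  return true;
}

// (trunc (ashr X, K)) == (ashr (trunc X), K) when X has at least
// SrcBW - DstBW + 1 sign bits.
bool ashrRuleHolds(unsigned K) {
  for (unsigned X = 0; X <= 0xFFFF; ++X) {
    if (numSignBits16(X) < SrcBW - DstBW + 1)
      continue;
    uint8_t Wide = static_cast<uint8_t>(ashr16(X, K));
    uint8_t Narrow = ashr8(static_cast<uint8_t>(X), K);
    if (Wide != Narrow)
      return false;
  }
  return true;
}
```

Dropping the LSHR precondition breaks the equivalence. With X = 0x0100 and K = 1, the wide form yields 0x80 while the narrow form yields 0.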
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 88b68d7685c63..f5346fe3d06df 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -12,6 +12,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
+#include "llvm/CodeGen/GlobalISel/GISelDemandedMask.h"
 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
@@ -2852,6 +2853,21 @@ void CombinerHelper::applyCombineTruncOfShift(
 
   Register ShiftAmt = ShiftMI->getOperand(2).getReg();
   Register ShiftSrc = ShiftMI->getOperand(1).getReg();
+
+  // Demanded-bits hook: the inner G_TRUNC consumes only the low NewShiftTy bits
+  // of ShiftSrc. If ShiftSrc has no other user and its producer is a redundant
+  // mask/disjoint-or, simplify it first so no dead high-bit work is emitted.
+  MachineInstr *ShiftSrcMI = MRI.getVRegDef(ShiftSrc);
+  if (ShiftSrcMI && MRI.hasOneNonDBGUse(ShiftSrc)) {
+    APInt LowMask =
+        APInt::getLowBitsSet(MRI.getType(ShiftSrc).getScalarSizeInBits(),
+                             NewShiftTy.getScalarSizeInBits());
+    trySimplifyDemandedBits(*ShiftSrcMI, /*OperandNo=*/0,
+                            GISelDemandedMask::forBits(MRI, ShiftSrc, LowMask));
+    // Re-read the operand: driver may have RAUW-rerouted it.
+    ShiftSrc = ShiftMI->getOperand(1).getReg();
+  }
+
   ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);
 
   Register NewShift =
@@ -8816,3 +8832,148 @@ bool CombinerHelper::matchAVG(MachineInstr &MI, MachineRegisterInfo &MRI,
   LLT XTy = MRI.getType(X);
   return XTy == MRI.getType(Y) && isLegal({TargetOpc, {XTy}});
 }
+
+std::optional<Register> CombinerHelper::getDemandedBitsSimplifiedReg(
+    MachineInstr &MI, unsigned OperandNo, const GISelDemandedMask &Mask) const {
+  if (OperandNo >= MI.getNumOperands() || !MI.getOperand(OperandNo).isReg() ||
+      !MI.getOperand(OperandNo).isDef())
+    return std::nullopt;
+  Register Dst = MI.getOperand(OperandNo).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  if (!DstTy.isValid())
+    return std::nullopt;
+  unsigned BW = DstTy.getScalarSizeInBits();
+  if (Mask.Bits.getBitWidth() != BW)
+    return std::nullopt;
+
+  switch (MI.getOpcode()) {
+  case TargetOpcode::G_AND: {
+    // (and X, C) with (C & Mask) == Mask -> X is sufficient.
+    Register LHS = MI.getOperand(1).getReg();
+    std::optional<APInt> C =
+        getConstantOrConstantSplatVector(MI.getOperand(2).getReg());
+    if (!C || C->getBitWidth() != BW)
+      return std::nullopt;
+    if (!Mask.Bits.isSubsetOf(*C))
+      return std::nullopt;
+    if (MRI.getType(LHS) != DstTy)
+      return std::nullopt;
+    return LHS;
+  }
+  case TargetOpcode::G_OR: {
+    // (or X, C) with (C & Mask) == 0 -> X is sufficient.
+    Register LHS = MI.getOperand(1).getReg();
+    std::optional<APInt> C =
+        getConstantOrConstantSplatVector(MI.getOperand(2).getReg());
+    if (!C || C->getBitWidth() != BW)
+      return std::nullopt;
+    if ((*C & Mask.Bits).getBoolValue())
+      return std::nullopt;
+    if (MRI.getType(LHS) != DstTy)
+      return std::nullopt;
+    return LHS;
+  }
+  default:
+    return std::nullopt;
+  }
+}
+
+bool CombinerHelper::trySimplifyDemandedBits(
+    MachineInstr &MI, unsigned OperandNo, const GISelDemandedMask &Mask) const {
+  if (std::optional<Register> R =
+          getDemandedBitsSimplifiedReg(MI, OperandNo, Mask)) {
+    replaceRegWith(MRI, MI.getOperand(OperandNo).getReg(), *R);
+    eraseInst(MI);
+    return true;
+  }
+  return false;
+}
+
+bool CombinerHelper::matchSimplifyDemandedBits(MachineInstr &MI,
+                                               BuildFnTy &MatchInfo) const {
+  if (MI.getNumExplicitDefs() != 1 || !MI.getOperand(0).isReg())
+    return false;
+  Register Dst = MI.getOperand(0).getReg();
+  LLT Ty = MRI.getType(Dst);
+  if (!Ty.isValid() || Ty.isVector())
+    return false;
+  unsigned BW = Ty.getScalarSizeInBits();
+  APInt Demand = VT->computeDemandedBits(Dst);
+  if (Demand.getBitWidth() != BW || Demand.isAllOnes())
+    return false; // nothing un-demanded -> no opportunity, and avoids re-firing
+  std::optional<Register> R = getDemandedBitsSimplifiedReg(
+      MI, /*OperandNo=*/0, GISelDemandedMask::forBits(MRI, Dst, Demand));
+  if (!R)
+    return false;
+  Register Repl = *R;
+  MatchInfo = [=](MachineIRBuilder &B) { replaceRegWith(MRI, Dst, Repl); };
+  return true;
+}
+
+// (trunc (lshr X, K)) with bits [DstBW, DstBW+K) of X known-zero
+//   -> (lshr (trunc X), K)
+// (trunc (ashr X, K)) when X has at least (SrcBW - DstBW + 1) sign bits
+//   -> (ashr (trunc X), K)
+bool CombinerHelper::matchNarrowTruncShrConst(MachineInstr &MI,
+                                              BuildFnTy &MatchInfo) const {
+  assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected G_TRUNC");
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+  if (!MRI.hasOneNonDBGUse(Src))
+    return false;
+
+  MachineInstr *ShrMI = getDefIgnoringCopies(Src, MRI);
+  if (!ShrMI)
+    return false;
+  unsigned ShrOpc = ShrMI->getOpcode();
+  if (ShrOpc != TargetOpcode::G_LSHR && ShrOpc != TargetOpcode::G_ASHR)
+    return false;
+
+  Register X = ShrMI->getOperand(1).getReg();
+  Register AmtReg = ShrMI->getOperand(2).getReg();
+  LLT SrcTy = MRI.getType(X);
+  LLT DstTy = MRI.getType(Dst);
+  if (SrcTy.isVector() != DstTy.isVector())
+    return false;
+  if (SrcTy.isVector() && SrcTy.getElementCount() != DstTy.getElementCount())
+    return false;
+
+  unsigned SrcBW = SrcTy.getScalarSizeInBits();
+  unsigned DstBW = DstTy.getScalarSizeInBits();
+  if (DstBW >= SrcBW)
+    return false;
+
+  std::optional<APInt> K = getConstantOrConstantSplatVector(AmtReg);
+  if (!K)
+    return false;
+  if (K->uge(DstBW))
+    return false;
+  unsigned KVal = K->getZExtValue();
+  if (KVal + DstBW > SrcBW)
+    return false;
+
+  if (!VT)
+    return false;
+
+  if (ShrOpc == TargetOpcode::G_LSHR) {
+    KnownBits Known = VT->getKnownBits(X);
+    APInt HiZeroes = Known.Zero.extractBits(KVal, DstBW);
+    if (!HiZeroes.isAllOnes())
+      return false;
+  } else {
+    unsigned SignBits = VT->computeNumSignBits(X);
+    if (SignBits < SrcBW - DstBW + 1)
+      return false;
+  }
+
+  LLT AmtTy = getTargetLowering().getPreferredShiftAmountTy(DstTy);
+  if (!isLegalOrBeforeLegalizer({ShrOpc, {DstTy, AmtTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    Register NarrowX = B.buildTrunc(DstTy, X).getReg(0);
+    Register NarrowAmt = B.buildConstant(AmtTy, KVal).getReg(0);
+    B.buildInstr(ShrOpc, {Dst}, {NarrowX, NarrowAmt});
+  };
+  return true;
+}
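
matchSimplifyDemandedBits reroutes every use of the def, so it is sound for multi-use values only because it first tests the union of all users' demands. Below is a hedged scalar sketch of that decision for the G_AND case. `canReplaceAndWithLHS` is hypothetical, and the 32-bit width is an assumption:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of the decision for `Dst = G_AND X, C` (scalar, 32-bit).
// Each entry of UserDemands is the set of Dst bits one user reads. Replacing
// Dst rewrites every user at once, so the test uses the union of demands.
bool canReplaceAndWithLHS(uint32_t C, const std::vector<uint32_t> &UserDemands) {
  uint32_t Union = 0;
  for (uint32_t D : UserDemands)
    Union |= D;
  if (Union == UINT32_MAX)
    return false; // every bit demanded: the match bails out early
  return (Union & ~C) == 0; // all demanded bits survive the mask
}
```

With C = 0x00FF and users demanding 0x000F and 0x0F00, the first user alone would permit the rewrite. The union (0x0F0F), however, includes bits the mask clears, so the rewrite is rejected.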
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
index ea4ab73700cf0..a0047f990f288 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
@@ -53,6 +53,32 @@ GISelValueTracking::GISelValueTracking(MachineFunction &MF, unsigned MaxDepth)
     : MF(MF), MRI(MF.getRegInfo()), TL(*MF.getSubtarget().getTargetLowering()),
       DL(MF.getFunction().getDataLayout()), MaxDepth(MaxDepth) {}
 
+GISelDemandedMask GISelDemandedMask::getAllBits(const MachineRegisterInfo &MRI,
+                                                Register R) {
+  LLT Ty = MRI.getType(R);
+  unsigned BW = Ty.isValid() ? Ty.getScalarSizeInBits() : 1;
+  APInt Elts =
+      Ty.isFixedVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
+  return GISelDemandedMask(APInt::getAllOnes(BW), std::move(Elts));
+}
+
+GISelDemandedMask GISelDemandedMask::forBits(const MachineRegisterInfo &MRI,
+                                             Register R, APInt Bits) {
+  LLT Ty = MRI.getType(R);
+  // Bits must match the register's scalar width. Mismatches would silently
+  // disable downstream demand narrowing (which falls back to all-ones when
+  // the width doesn't line up), so refuse them at the construction boundary.
+  assert((!Ty.isValid() || Bits.getBitWidth() == Ty.getScalarSizeInBits()) &&
+         "GISelDemandedMask::forBits: Bits width must match register's "
+         "scalar bit width");
+  return getAllBits(MRI, R).withBits(std::move(Bits));
+}
+
+GISelDemandedMask GISelDemandedMask::forElts(const MachineRegisterInfo &MRI,
+                                             Register R, APInt Elts) {
+  return getAllBits(MRI, R).withElts(std::move(Elts));
+}
+
 Align GISelValueTracking::computeKnownAlignment(Register R, unsigned Depth) {
   const MachineInstr *MI = MRI.getVRegDef(R);
   switch (MI->getOpcode()) {
@@ -82,23 +108,31 @@ KnownBits GISelValueTracking::getKnownBits(MachineInstr &MI) {
 }
 
 KnownBits GISelValueTracking::getKnownBits(Register R) {
-  const LLT Ty = MRI.getType(R);
-  // Since the number of lanes in a scalable vector is unknown at compile time,
-  // we track one bit which is implicitly broadcast to all lanes.  This means
-  // that all lanes in a scalable vector are considered demanded.
-  APInt DemandedElts =
-      Ty.isFixedVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
-  return getKnownBits(R, DemandedElts);
+  return getKnownBits(R, GISelDemandedMask::getAllBits(MRI, R));
 }
 
 KnownBits GISelValueTracking::getKnownBits(Register R,
-                                           const APInt &DemandedElts,
+                                           const GISelDemandedMask &Mask,
                                            unsigned Depth) {
   KnownBits Known;
-  computeKnownBitsImpl(R, Known, DemandedElts, Depth);
+  computeKnownBitsImpl(R, Known, Mask, Depth);
   return Known;
 }
 
+KnownBits GISelValueTracking::getKnownBits(Register R,
+                                           const APInt &DemandedElts,
+                                           unsigned Depth) {
+  return getKnownBits(R, GISelDemandedMask::forElts(MRI, R, DemandedElts),
+                      Depth);
+}
+
+void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth) {
+  computeKnownBitsImpl(R, Known,
+                       GISelDemandedMask::forElts(MRI, R, DemandedElts), Depth);
+}
+
 bool GISelValueTracking::signBitIsZero(Register R) {
   LLT Ty = MRI.getType(R);
   unsigned BitWidth = Ty.getScalarSizeInBits();
@@ -127,22 +161,153 @@ dumpResult(const MachineInstr &MI, const KnownBits &Known, unsigned Depth) {
 /// Compute known bits for the intersection of \p Src0 and \p Src1
 void GISelValueTracking::computeKnownBitsMin(Register Src0, Register Src1,
                                              KnownBits &Known,
-                                             const APInt &DemandedElts,
+                                             const GISelDemandedMask &Mask,
                                              unsigned Depth) {
   // Test src1 first, since we canonicalize simpler expressions to the RHS.
-  computeKnownBitsImpl(Src1, Known, DemandedElts, Depth);
+  computeKnownBitsImpl(Src1, Known, Mask, Depth);
 
   // If we don't know any bits, early out.
   if (Known.isUnknown())
     return;
 
   KnownBits Known2;
-  computeKnownBitsImpl(Src0, Known2, DemandedElts, Depth);
+  computeKnownBitsImpl(Src0, Known2, Mask, Depth);
 
   // Only known if known in both the LHS and RHS.
   Known = Known.intersectWith(Known2);
 }
 
+// Back-propagate the demanded-bits mask of a SHL/LSHR/ASHR result into its
+// shifted-value source operand, taking the union of routed source bits over
+// the feasible shift-amount range [Lo, Hi]. For SHL, result bit p is sourced
+// from bit (p - c); for LSHR/ASHR direct route, from bit (p + c). For ASHR
+// additionally, any demanded bit at position >= BW - Hi reaches the source
+// sign bit via sign-fill. Returns the safe all-ones source demand when the
+// range is not provable, empty, wrapped, or extends past BitWidth, or when
+// the local and outer bitwidths disagree (cross-bitwidth recursion through
+// COPY/PHI/EXTRACT/CONCAT/SHUFFLE).
+static APInt getShiftSrcDemandedBits(GISelValueTracking &VT, Register AmtReg,
+                                     unsigned BitWidth,
+                                     const GISelDemandedMask &Mask,
+                                     unsigned ShiftOpcode, unsigned Depth) {
+  if (Mask.Bits.getBitWidth() != BitWidth)
+    return APInt::getAllOnes(BitWidth);
+  std::optional<ConstantRange> ShAmtRange =
+      VT.getValidShiftAmountRange(AmtReg, Mask.Elts, Depth);
+  if (!ShAmtRange || ShAmtRange->isEmptySet() || ShAmtRange->isWrappedSet())
+    return APInt::getAllOnes(BitWidth);
+  const APInt &Hi = ShAmtRange->getUnsignedMax();
+  if (Hi.uge(BitWidth))
+    return APInt::getAllOnes(BitWidth);
+  uint64_t LoVal = ShAmtRange->getUnsignedMin().getZExtValue();
+  uint64_t HiVal = Hi.getZExtValue();
+
+  APInt SrcBits = APInt::getZero(BitWidth);
+  if (ShiftOpcode == TargetOpcode::G_SHL) {
+    for (uint64_t C = LoVal; C <= HiVal; ++C)
+      SrcBits |= Mask.Bits.lshr(C);
+  } else {
+    for (uint64_t C = LoVal; C <= HiVal; ++C)
+      SrcBits |= Mask.Bits.shl(C);
+    if (ShiftOpcode == TargetOpcode::G_ASHR) {
+      // Sign-fill broadcast: any demanded bit p with p + c >= BW for some
+      // feasible c maps to the source sign bit; binding case is c = Hi.
+      APInt SignFillReach = APInt::getBitsSetFrom(BitWidth, BitWidth - HiVal);
+      if (Mask.Bits.intersects(SignFillReach))
+        SrcBits.setBit(BitWidth - 1);
+    }
+  }
+  return SrcBits;
+}
+
+APInt GISelValueTracking::getDemandedBitsForUse(const MachineInstr &UseMI,
+                                                unsigned OpIdx,
+                                                const APInt &DemandOfUser) {
+  Register OpReg = UseMI.getOperand(OpIdx).getReg();
+  LLT OpTy = MRI.getType(OpReg);
+  if (!OpTy.isValid() || OpTy.isVector())
+    return APInt::getAllOnes(OpTy.isValid() ? OpTy.getScalarSizeInBits() : 1);
+  unsigned OpBW = OpTy.getScalarSizeInBits();
+
+  switch (UseMI.getOpcode()) {
+  case TargetOpcode::G_TRUNC:
+    // def (narrow) bit p == operand bit p; high operand bits not demanded.
+    return DemandOfUser.zext(OpBW);
+  case TargetOpcode::G_AND: {
+    // result = X & C. If the OTHER operand is a constant C, X is demanded only
+    // where the result is demanded AND C is set; else X inherits DemandOfUser.
+    unsigned Other = OpIdx == 1 ? 2 : 1;
+    if (auto C = getIConstantVRegValWithLookThrough(
+            UseMI.getOperand(Other).getReg(), MRI);
+        C && C->Value.getBitWidth() == OpBW &&
+        DemandOfUser.getBitWidth() == OpBW)
+      return DemandOfUser & C->Value;
+    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
+                                              : APInt::getAllOnes(OpBW);
+  }
+  case TargetOpcode::G_OR: {
+    // result = X | C. Where C is set, result is 1 regardless of X, so X is
+    // not demanded there. Only a constant other-operand is modeled.
+    unsigned Other = OpIdx == 1 ? 2 : 1;
+    if (auto C = getIConstantVRegValWithLookThrough(
+            UseMI.getOperand(Other).getReg(), MRI);
+        C && C->Value.getBitWidth() == OpBW &&
+        DemandOfUser.getBitWidth() == OpBW)
+      return DemandOfUser & ~C->Value;
+    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
+                                              : APInt::getAllOnes(OpBW);
+  }
+  case TargetOpcode::G_XOR:
+    // Bit-local: each demanded result bit needs the matching operand bit.
+    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
+                                              : APInt::getAllOnes(OpBW);
+  case TargetOpcode::G_ZEXT:
+    // def low OpBW bits == operand; higher def bits are always zero, so they
+    // demand no operand bits.
+    return DemandOfUser.getBitWidth() >= OpBW ? DemandOfUser.trunc(OpBW)
+                                              : APInt::getAllOnes(OpBW);
+  case TargetOpcode::G_SEXT: {
+    // def low OpBW bits == operand; every def bit at index >= OpBW-1 comes from
+    // the operand sign bit (OpBW-1).
+    if (OpBW == 0 || DemandOfUser.getBitWidth() < OpBW)
+      return APInt::getAllOnes(OpBW);
+    APInt Low = DemandOfUser.trunc(OpBW);
+    if (!DemandOfUser.lshr(OpBW - 1).isZero())
+      Low.setBit(OpBW - 1);
+    return Low;
+  }
+  default:
+    return APInt::getAllOnes(OpBW);
+  }
+}
+
+APInt GISelValueTracking::computeDemandedBits(Register R, unsigned Depth) {
+  LLT Ty = MRI.getType(R);
+  unsigned BW = Ty.isValid() ? Ty.getScalarSizeInBits() : 1;
+  if (!Ty.isValid() || Ty.isVector())
+    return APInt::getAllOnes(BW);
+  if (Depth >= MaxAnalysisRecursionDepth)
+    return APInt::getAllOnes(BW);
+
+  APInt Demand = APInt::getZero(BW);
+  for (const MachineInstr &UseMI : MRI.use_nodbg_instructions(R)) {
+    if (UseMI.getNumExplicitDefs() != 1) {
+      Demand = APInt::getAllOnes(BW);
+      break;
+    }
+    APInt UserDemand =
+        computeDemandedBits(UseMI.getOperand(0).getReg(), Depth + 1);
+    for (unsigned I = 0, E = UseMI.getNumOperands(); I != E; ++I) {
+      const MachineOperand &MO = UseMI.getOperand(I);
+      if (MO.isReg() && !MO.isDef() && MO.getReg() == R)
+        Demand |= getDemandedBitsForUse(UseMI, I, UserDemand);
+    }
+    if (Demand.isAllOnes())
+      break;
+  }
+  return Demand;
+}
+
 // Bitfield extract is computed as (Src >> Offset) & Mask, where Mask is
 // created using Width. Use this function when the inputs are KnownBits
 // objects. TODO: Move this KnownBits.h if this is usable in more cases.
@@ -158,8 +323,9 @@ static KnownBits extractBits(unsigned BitWidth, const KnownBits &SrcOpKnown,
 }
 
 void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
-                                              const APInt &DemandedElts,
+                                              const GISelDemandedMask &Mask,
                                               unsigned Depth) {
+  const APInt &DemandedElts = Mask.Elts;
   MachineInstr &MI = *MRI.getVRegDef(R);
   unsigned Opcode = MI.getOpcode();
   LLT DstTy = MRI.getType(R);
@@ -215,7 +381,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
       if (!DemandedElts[I])
         continue;
 
-      computeKnownBitsImpl(MO.getReg(), Known2, APInt(1, 1), Depth + 1);
+      computeKnownBitsImpl(MO.getReg(), Known2, Mask.forScalarElement(),
+                           Depth + 1);
 
       // Known bits are the values that are shared by every demanded element.
       Known = Known.intersectWith(Known2);
@@ -227,8 +394,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     break;
   }
   case TargetOpcode::G_SPLAT_VECTOR: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, APInt(1, 1),
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known,
+                         Mask.forScalarElement(), Depth + 1);
     // Implicitly truncate the bits to match the official semantics of
     // G_SPLAT_VECTOR.
     Known = Known.trunc(BitWidth);
@@ -269,7 +436,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
         }
 
         // For COPYs we don't do anything, don't increase the depth.
-        computeKnownBitsImpl(SrcReg, Known2, NowDemandedElts,
+        computeKnownBitsImpl(SrcReg, Known2, Mask.withElts(NowDemandedElts),
                              Depth + (Opcode != TargetOpcode::COPY));
         Known2 = Known2.anyextOrTrunc(BitWidth);
         Known = Known.intersectWith(Known2);
@@ -319,19 +486,15 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     break;
   }
   case TargetOpcode::G_SUB: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::sub(Known, Known2, MI.getFlag(MachineInstr::NoSWrap),
                            MI.getFlag(MachineInstr::NoUWrap));
     break;
   }
   case TargetOpcode::G_XOR: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
 
     Known ^= Known2;
     break;
@@ -346,70 +509,54 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     [[fallthrough]];
   }
   case TargetOpcode::G_ADD: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::add(Known, Known2);
     break;
   }
   case TargetOpcode::G_AND: {
     // If either the LHS or the RHS are Zero, the result is zero.
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
 
     Known &= Known2;
     break;
   }
   case TargetOpcode::G_OR: {
     // If either the LHS or the RHS are Zero, the result is zero.
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
 
     Known |= Known2;
     break;
   }
   case TargetOpcode::G_MUL: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::mul(Known, Known2);
     break;
   }
   case TargetOpcode::G_UMULH: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::mulhu(Known, Known2);
     break;
   }
   case TargetOpcode::G_SMULH: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::mulhs(Known, Known2);
     break;
   }
   case TargetOpcode::G_ABDU: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::abdu(Known, Known2);
     break;
   }
   case TargetOpcode::G_ABDS: {
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::abds(Known, Known2);
 
     unsigned SignBits1 =
@@ -424,19 +571,15 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     break;
   }
   case TargetOpcode::G_UDIV: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::udiv(Known, Known2,
                             MI.getFlag(MachineInstr::MIFlag::IsExact));
     break;
   }
   case TargetOpcode::G_SDIV: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::sdiv(Known, Known2,
                             MI.getFlag(MachineInstr::MIFlag::IsExact));
     break;
@@ -445,10 +588,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     KnownBits LHSKnown(Known.getBitWidth());
     KnownBits RHSKnown(Known.getBitWidth());
 
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, Mask, Depth + 1);
 
     Known = KnownBits::urem(LHSKnown, RHSKnown);
     break;
@@ -457,54 +598,44 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     KnownBits LHSKnown(Known.getBitWidth());
     KnownBits RHSKnown(Known.getBitWidth());
 
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, Mask, Depth + 1);
 
     Known = KnownBits::srem(LHSKnown, RHSKnown);
     break;
   }
   case TargetOpcode::G_SELECT: {
     computeKnownBitsMin(MI.getOperand(2).getReg(), MI.getOperand(3).getReg(),
-                        Known, DemandedElts, Depth + 1);
+                        Known, Mask, Depth + 1);
     break;
   }
   case TargetOpcode::G_SMIN: {
     // TODO: Handle clamp pattern with number of sign bits
     KnownBits KnownRHS;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, Mask, Depth + 1);
     Known = KnownBits::smin(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_SMAX: {
     // TODO: Handle clamp pattern with number of sign bits
     KnownBits KnownRHS;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, Mask, Depth + 1);
     Known = KnownBits::smax(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMIN: {
     KnownBits KnownRHS;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, Mask, Depth + 1);
     Known = KnownBits::umin(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMAX: {
     KnownBits KnownRHS;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, Mask, Depth + 1);
     Known = KnownBits::umax(Known, KnownRHS);
     break;
   }
@@ -520,8 +651,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     break;
   }
   case TargetOpcode::G_SEXT: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
     // If the sign bit is known to be zero or one, then sext will extend
     // it to the top bits, else it will just zext.
     Known = Known.sext(BitWidth);
@@ -529,14 +659,12 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   }
   case TargetOpcode::G_ASSERT_SEXT:
   case TargetOpcode::G_SEXT_INREG: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
     Known = Known.sextInReg(MI.getOperand(2).getImm());
     break;
   }
   case TargetOpcode::G_ANYEXT: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
     Known = Known.anyext(BitWidth);
     break;
   }
@@ -562,29 +690,35 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     break;
   }
   case TargetOpcode::G_ASHR: {
+    Register SrcReg = MI.getOperand(1).getReg();
+    Register AmtReg = MI.getOperand(2).getReg();
+    APInt SrcBits = getShiftSrcDemandedBits(*this, AmtReg, BitWidth, Mask,
+                                            Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
+    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
     Known = KnownBits::ashr(LHSKnown, RHSKnown);
     break;
   }
   case TargetOpcode::G_LSHR: {
+    Register SrcReg = MI.getOperand(1).getReg();
+    Register AmtReg = MI.getOperand(2).getReg();
+    APInt SrcBits = getShiftSrcDemandedBits(*this, AmtReg, BitWidth, Mask,
+                                            Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
+    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
     Known = KnownBits::lshr(LHSKnown, RHSKnown);
     break;
   }
   case TargetOpcode::G_SHL: {
+    Register SrcReg = MI.getOperand(1).getReg();
+    Register AmtReg = MI.getOperand(2).getReg();
+    APInt SrcBits = getShiftSrcDemandedBits(*this, AmtReg, BitWidth, Mask,
+                                            Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), LHSKnown, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
+    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
     Known = KnownBits::shl(LHSKnown, RHSKnown);
     break;
   }
@@ -596,7 +730,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
       break;
 
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
 
     unsigned Amt = MaybeAmtOp->urem(BitWidth);
 
@@ -616,10 +750,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
       break;
 
     const APInt Amt = *MaybeAmtOp;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, Mask, Depth + 1);
     Known = Opcode == TargetOpcode::G_FSHL
                 ? KnownBits::fshl(Known, Known2, Amt)
                 : KnownBits::fshr(Known, Known2, Amt);
@@ -634,13 +766,13 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   case TargetOpcode::G_ZEXT:
   case TargetOpcode::G_TRUNC: {
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
     Known = Known.zextOrTrunc(BitWidth);
     break;
   }
   case TargetOpcode::G_ASSERT_ZEXT: {
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
 
     unsigned SrcBitWidth = MI.getOperand(2).getImm();
     assert(SrcBitWidth && "SrcBitWidth can't be zero");
@@ -665,8 +797,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
 
     for (unsigned I = 0; I != NumOps - 1; ++I) {
       KnownBits SrcOpKnown;
-      computeKnownBitsImpl(MI.getOperand(I + 1).getReg(), SrcOpKnown,
-                           DemandedElts, Depth + 1);
+      computeKnownBitsImpl(MI.getOperand(I + 1).getReg(), SrcOpKnown, Mask,
+                           Depth + 1);
       Known.insertBits(SrcOpKnown, I * OpSize);
     }
     break;
@@ -693,7 +825,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     }
 
     KnownBits SrcOpKnown;
-    computeKnownBitsImpl(SrcReg, SrcOpKnown, SubDemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, SrcOpKnown, Mask.withElts(SubDemandedElts),
+                         Depth + 1);
 
     if (SrcTy.isVector())
       Known = std::move(SrcOpKnown);
@@ -703,19 +836,18 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   }
   case TargetOpcode::G_BSWAP: {
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
     Known = Known.byteSwap();
     break;
   }
   case TargetOpcode::G_BITREVERSE: {
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
     Known = Known.reverseBits();
     break;
   }
   case TargetOpcode::G_CTPOP: {
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, Mask, Depth + 1);
     // We can bound the space the count needs.  Also, bits known to be zero
     // can't contribute to the population.
     unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -727,22 +859,22 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   }
   case TargetOpcode::G_UBFX: {
     KnownBits SrcOpKnown, OffsetKnown, WidthKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, Mask,
                          Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), OffsetKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), OffsetKnown, Mask,
                          Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(3).getReg(), WidthKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(3).getReg(), WidthKnown, Mask,
                          Depth + 1);
     Known = extractBits(BitWidth, SrcOpKnown, OffsetKnown, WidthKnown);
     break;
   }
   case TargetOpcode::G_SBFX: {
     KnownBits SrcOpKnown, OffsetKnown, WidthKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, Mask,
                          Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), OffsetKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), OffsetKnown, Mask,
                          Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(3).getReg(), WidthKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(3).getReg(), WidthKnown, Mask,
                          Depth + 1);
     OffsetKnown = OffsetKnown.sext(BitWidth);
     WidthKnown = WidthKnown.sext(BitWidth);
@@ -773,18 +905,15 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     // With [US]ADDE, a carry bit may be added in.
     KnownBits Carry(1);
     if (Opcode == TargetOpcode::G_UADDE || Opcode == TargetOpcode::G_SADDE) {
-      computeKnownBitsImpl(MI.getOperand(4).getReg(), Carry, DemandedElts,
-                           Depth + 1);
+      computeKnownBitsImpl(MI.getOperand(4).getReg(), Carry, Mask, Depth + 1);
       // Carry has bit width 1
       Carry = Carry.trunc(1);
     } else {
       Carry.setAllZero();
     }
 
-    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
-                         Depth + 1);
-    computeKnownBitsImpl(MI.getOperand(3).getReg(), Known2, DemandedElts,
-                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, Mask, Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(3).getReg(), Known2, Mask, Depth + 1);
     Known = KnownBits::computeForAddCarry(Known, Known2, Carry);
     break;
   }
@@ -807,7 +936,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   case TargetOpcode::G_CTTZ:
   case TargetOpcode::G_CTTZ_ZERO_POISON: {
     KnownBits SrcOpKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, Mask,
                          Depth + 1);
     // If we have a known 1, its position is our upper bound
     unsigned PossibleTZ = SrcOpKnown.countMaxTrailingZeros();
@@ -818,7 +947,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   case TargetOpcode::G_CTLZ:
   case TargetOpcode::G_CTLZ_ZERO_POISON: {
     KnownBits SrcOpKnown;
-    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, DemandedElts,
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), SrcOpKnown, Mask,
                          Depth + 1);
     // If we have a known 1, its position is our upper bound.
     unsigned PossibleLZ = SrcOpKnown.countMaxLeadingZeros();
@@ -867,7 +996,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
       DemandedSrcElts =
           APInt::getOneBitSet(NumSrcElts, ConstEltNo->getZExtValue());
 
-    computeKnownBitsImpl(InVec, Known, DemandedSrcElts, Depth + 1);
+    computeKnownBitsImpl(InVec, Known, Mask.withElts(DemandedSrcElts),
+                         Depth + 1);
     break;
   }
   case TargetOpcode::G_SHUFFLE_VECTOR: {
@@ -883,16 +1013,16 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     if (!!DemandedLHS) {
-      computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedLHS,
-                           Depth + 1);
+      computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2,
+                           Mask.withElts(DemandedLHS), Depth + 1);
       Known = Known.intersectWith(Known2);
     }
     // If we don't know any bits, early out.
     if (Known.isUnknown())
       break;
     if (!!DemandedRHS) {
-      computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedRHS,
-                           Depth + 1);
+      computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2,
+                           Mask.withElts(DemandedRHS), Depth + 1);
       Known = Known.intersectWith(Known2);
     }
     break;
@@ -910,7 +1040,8 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
       APInt DemandedSub =
           DemandedElts.extractBits(NumSubVectorElts, I * NumSubVectorElts);
       if (!!DemandedSub) {
-        computeKnownBitsImpl(MO.getReg(), Known2, DemandedSub, Depth + 1);
+        computeKnownBitsImpl(MO.getReg(), Known2, Mask.withElts(DemandedSub),
+                             Depth + 1);
 
         Known = Known.intersectWith(Known2);
       }
@@ -922,7 +1053,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
   }
   case TargetOpcode::G_ABS: {
     Register SrcReg = MI.getOperand(1).getReg();
-    computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
+    computeKnownBitsImpl(SrcReg, Known, Mask, Depth + 1);
     Known = Known.abs();
     Known.Zero.setHighBits(computeNumSignBits(SrcReg, DemandedElts, Depth + 1) -
                            1);
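For readers following the G_SHL / G_LSHR / G_ASHR changes that call `getShiftSrcDemandedBits` above, here is a minimal standalone sketch of the same narrowing rule. It is not code from this PR: `ShiftKind`, `lowBits` and `shiftSrcDemanded` are hypothetical names, plain `uint64_t` masks stand in for `APInt` (so widths are assumed to be at most 64 bits), and the shift amount is assumed to be known only as an unsigned range [Lo, Hi].

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of the shift source-bit narrowing; not LLVM API.
enum class ShiftKind { Shl, LShr, AShr };

// Mask with the low BW bits set (BW in [0, 64]).
static uint64_t lowBits(unsigned BW) {
  return BW >= 64 ? ~0ULL : ((1ULL << BW) - 1);
}

// Returns the source bits that can influence the demanded result bits when the
// shift amount lies somewhere in [Lo, Hi]. Like the GISel helper, it falls back
// to all-ones when the range is empty or reaches BW (an out-of-range shift).
static uint64_t shiftSrcDemanded(ShiftKind K, uint64_t Demanded, unsigned BW,
                                 unsigned Lo, unsigned Hi) {
  assert(BW >= 1 && BW <= 64 && "sketch only models widths up to 64");
  const uint64_t All = lowBits(BW);
  if (Lo > Hi || Hi >= BW)
    return All;
  uint64_t Src = 0;
  for (unsigned C = Lo; C <= Hi; ++C)
    // SHL: result bit p reads source bit p - C. LSHR/ASHR: it reads p + C.
    Src |= K == ShiftKind::Shl ? (Demanded >> C) : (Demanded << C);
  Src &= All;
  if (K == ShiftKind::AShr) {
    // Result bits in [BW - Hi, BW) may be copies of the source sign bit.
    uint64_t SignFillReach = All & ~lowBits(BW - Hi);
    if (Demanded & SignFillReach)
      Src |= 1ULL << (BW - 1);
  }
  return Src;
}
```

For example, with BW = 8 and only the top result bit demanded, an ASHR by 3 demands just the source sign bit (0x80), while an LSHR by 8 of a demanded low byte in a 32-bit value demands 0xFF00.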
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index a9c447336cd5e..f2d12f46e0fca 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -397,5 +397,6 @@ def AArch64PostLegalizerCombiner
                         combine_mul_cmlt, combine_use_vector_truncate,
                         extmultomull, subaddmulreassoc, truncsat_combines,
                         lshr_of_trunc_of_lshr,
-                        funnel_shift_from_or_shift_constants_are_legal]> {
+                        funnel_shift_from_or_shift_constants_are_legal,
+                        simplify_demanded_bits]> {
 }
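The `simplify_demanded_bits` rule added to the AArch64 post-legalizer combiner above depends on deciding when a G_AND or G_OR is redundant under a demanded mask (the role the commit message gives `getDemandedBitsSimplifiedReg`). The sketch below is a hedged illustration of that decision, modeled on the usual SelectionDAG known-bits conditions. `Known`, `Operand`, `simplifyAnd` and `simplifyOr` are hypothetical names, and the PR's exact conditions may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for llvm::KnownBits on a scalar of at most 64 bits.
struct Known {
  uint64_t Zero = 0; // bits proven to be 0
  uint64_t One = 0;  // bits proven to be 1
};

enum class Operand { LHS, RHS };

// AND: on a bit where LHS is known 0 or RHS is known 1, the result equals LHS.
// If that holds for every demanded bit, the AND can be replaced by LHS
// (and symmetrically by RHS).
std::optional<Operand> simplifyAnd(uint64_t Demanded, Known L, Known R) {
  if ((Demanded & ~(L.Zero | R.One)) == 0)
    return Operand::LHS;
  if ((Demanded & ~(R.Zero | L.One)) == 0)
    return Operand::RHS;
  return std::nullopt;
}

// OR: on a bit where LHS is known 1 or RHS is known 0, the result equals LHS.
std::optional<Operand> simplifyOr(uint64_t Demanded, Known L, Known R) {
  if ((Demanded & ~(L.One | R.Zero)) == 0)
    return Operand::LHS;
  if ((Demanded & ~(R.One | L.Zero)) == 0)
    return Operand::RHS;
  return std::nullopt;
}
```

For instance, `G_AND %x, 0xFF` on s32 is redundant when only the low nibble is demanded, and also when `%x` is already known to be zero above bit 7.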
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fb053edfc7a9e..07c64be902289 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -19684,8 +19684,7 @@ static void knownBitsForSBFE(const MachineInstr &MI, GISelValueTracking &VT,
   if (Width >= BFEWidth) // Ill-formed.
     return;
 
-  VT.computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
-                          Depth + 1);
+  Known = VT.getKnownBits(MI.getOperand(1).getReg(), DemandedElts, Depth + 1);
 
   Known = Known.extractBits(Width, Offset);
 
@@ -19734,9 +19733,8 @@ void SITargetLowering::computeKnownBitsForTargetInstr(
       Known.Zero.setBitsFrom(IID == Intrinsic::amdgcn_mbcnt_lo
                                  ? getSubtarget()->getWavefrontSizeLog2()
                                  : 5);
-      KnownBits Known2;
-      VT.computeKnownBitsImpl(MI->getOperand(3).getReg(), Known2, DemandedElts,
-                              Depth + 1);
+      KnownBits Known2 =
+          VT.getKnownBits(MI->getOperand(3).getReg(), DemandedElts, Depth + 1);
       Known = KnownBits::add(Known, Known2);
       break;
     }
@@ -19766,18 +19764,15 @@ void SITargetLowering::computeKnownBitsForTargetInstr(
   case AMDGPU::G_AMDGPU_UMED3: {
     auto [Dst, Src0, Src1, Src2] = MI->getFirst4Regs();
 
-    KnownBits Known2;
-    VT.computeKnownBitsImpl(Src2, Known2, DemandedElts, Depth + 1);
+    KnownBits Known2 = VT.getKnownBits(Src2, DemandedElts, Depth + 1);
     if (Known2.isUnknown())
       break;
 
-    KnownBits Known1;
-    VT.computeKnownBitsImpl(Src1, Known1, DemandedElts, Depth + 1);
+    KnownBits Known1 = VT.getKnownBits(Src1, DemandedElts, Depth + 1);
     if (Known1.isUnknown())
       break;
 
-    KnownBits Known0;
-    VT.computeKnownBitsImpl(Src0, Known0, DemandedElts, Depth + 1);
+    KnownBits Known0 = VT.getKnownBits(Src0, DemandedElts, Depth + 1);
     if (Known0.isUnknown())
       break;
 
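The reverse transfer in `getDemandedBitsForUse` and the union over users in `computeDemandedBits` (both in the GISelValueTracking.cpp hunk earlier) can be modeled in a few lines. This is a simplified toy, not the PR's code: `UserOp`, `UseInfo`, `demandForUse` and `demandOverUsers` are hypothetical. Each use here already carries its user's demand (the real code obtains it by recursing, depth-capped), and masks are `uint64_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy model; widths are assumed to be in [1, 64].
enum class UserOp { Trunc, AndConst, OrConst, Xor, ZExt, SExt, Other };

struct UseInfo {
  UserOp Op;
  uint64_t Const = 0;      // the other operand for AndConst / OrConst
  uint64_t UserDemand = 0; // bits demanded of the user's result
};

static uint64_t lowBits(unsigned BW) {
  return BW >= 64 ? ~0ULL : ((1ULL << BW) - 1);
}

// Bits of an OpBW-wide operand that one user demands.
static uint64_t demandForUse(const UseInfo &U, unsigned OpBW) {
  assert(OpBW >= 1 && OpBW <= 64 && "sketch only models widths up to 64");
  const uint64_t All = lowBits(OpBW);
  switch (U.Op) {
  case UserOp::Trunc: // narrow result bit p == operand bit p
  case UserOp::Xor:   // bit-local
  case UserOp::ZExt:  // high result bits are constant zero
    return U.UserDemand & All;
  case UserOp::AndConst: // only where C is set
    return U.UserDemand & U.Const & All;
  case UserOp::OrConst: // only where C is clear
    return U.UserDemand & ~U.Const & All;
  case UserOp::SExt: { // result bits >= OpBW - 1 all read the sign bit
    uint64_t D = U.UserDemand & All;
    if (U.UserDemand >> (OpBW - 1))
      D |= 1ULL << (OpBW - 1);
    return D;
  }
  case UserOp::Other: // conservative fallback: demand everything
    return All;
  }
  return All;
}

// Union of the demand over all users; with no users nothing is demanded.
static uint64_t demandOverUsers(const std::vector<UseInfo> &Uses,
                                unsigned OpBW) {
  uint64_t D = 0;
  for (const UseInfo &U : Uses)
    D |= demandForUse(U, OpBW);
  return D;
}
```

Like the patch, every unmodeled user falls back to all-ones, so the union can over-approximate but never under-demand.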
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
new file mode 100644
index 0000000000000..fda97e9cfc4e4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
@@ -0,0 +1,193 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 6
+# RUN: llc -o - -mtriple=aarch64-unknown-unknown -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs %s | FileCheck %s
+
+# narrow_trunc_shr_const: (trunc (lshr (and X, low-32-mask), K-const)) -> (lshr (trunc X), K-const)
+# AND proves bits [32..) zero; K=5; K+DstBW=37 <= SrcBW=64.
+---
+name:            narrow_trunc_lshr_const_lowmask
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $x0
+    ; CHECK-LABEL: name: narrow_trunc_lshr_const_lowmask
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
+    ; CHECK-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 27
+    ; CHECK-NEXT: [[UBFX:%[0-9]+]]:_(s64) = G_UBFX [[COPY]], [[C]](s64), [[C1]]
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[UBFX]](s64)
+    ; CHECK-NEXT: $w0 = COPY [[TRUNC]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = G_CONSTANT i64 4294967295
+    %2:_(s64) = G_AND %0, %1
+    %3:_(s64) = G_CONSTANT i64 5
+    %4:_(s64) = G_LSHR %2, %3
+    %5:_(s32) = G_TRUNC %4(s64)
+    $w0 = COPY %5(s32)
+    RET_ReallyLR implicit $w0
+...
+
+# Same as above but with zext-derived source: high bits provably zero.
+---
+name:            narrow_trunc_lshr_const_zext
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: narrow_trunc_lshr_const_zext
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 7
+    ; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[COPY]], [[C]](s32)
+    ; CHECK-NEXT: $w0 = COPY [[LSHR]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %1:_(s64) = G_ZEXT %0(s32)
+    %2:_(s64) = G_CONSTANT i64 7
+    %3:_(s64) = G_LSHR %1, %2
+    %4:_(s32) = G_TRUNC %3(s64)
+    $w0 = COPY %4(s32)
+    RET_ReallyLR implicit $w0
+...
+
+# ASHR positive: sext-derived source has 33 sign bits >= 64-32+1.
+---
+name:            narrow_trunc_ashr_const_sext
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: narrow_trunc_ashr_const_sext
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 11
+    ; CHECK-NEXT: [[ASHR:%[0-9]+]]:_(s32) = G_ASHR [[COPY]], [[C]](s32)
+    ; CHECK-NEXT: $w0 = COPY [[ASHR]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %1:_(s64) = G_SEXT %0(s32)
+    %2:_(s64) = G_CONSTANT i64 11
+    %3:_(s64) = G_ASHR %1, %2
+    %4:_(s32) = G_TRUNC %3(s64)
+    $w0 = COPY %4(s32)
+    RET_ReallyLR implicit $w0
+...
+
+# Negative: unknown high bits. Must NOT rewrite trunc(lshr X, K) to lshr(trunc X, K).
+# (trunc_shift may still narrow via its own mid-VT path; this only checks that
+# the narrow_trunc_shr_const rewrite does not fire on its own.)
+---
+name:            negative_unknown_high_bits
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $x0
+    ; CHECK-LABEL: name: negative_unknown_high_bits
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
+    ; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[COPY]], [[C]](s64)
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[LSHR]](s64)
+    ; CHECK-NEXT: $w0 = COPY [[TRUNC]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = G_CONSTANT i64 5
+    %2:_(s64) = G_LSHR %0, %1
+    %3:_(s32) = G_TRUNC %2(s64)
+    $w0 = COPY %3(s32)
+    RET_ReallyLR implicit $w0
+...
+
+# Negative: K >= DstBW. The narrowing must not fire (an s32 G_LSHR by 32 is out
+# of range); the result still folds to 0, since only zero-extended bits remain.
+---
+name:            negative_shift_geq_dstbw
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: negative_shift_geq_dstbw
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+    ; CHECK-NEXT: $w0 = COPY [[C]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %1:_(s64) = G_ZEXT %0(s32)
+    %2:_(s64) = G_CONSTANT i64 32
+    %3:_(s64) = G_LSHR %1, %2
+    %4:_(s32) = G_TRUNC %3(s64)
+    $w0 = COPY %4(s32)
+    RET_ReallyLR implicit $w0
+...
+
+# Negative: shift has multiple uses. The single-use guard blocks the rewrite.
+---
+name:            negative_multi_use_shift
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: negative_multi_use_shift
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[COPY]](s32)
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 3
+    ; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[ZEXT]], [[C]](s64)
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[LSHR]](s64)
+    ; CHECK-NEXT: $x0 = COPY [[LSHR]](s64)
+    ; CHECK-NEXT: $w1 = COPY [[TRUNC]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $x0, implicit $w1
+    %0:_(s32) = COPY $w0
+    %1:_(s64) = G_ZEXT %0(s32)
+    %2:_(s64) = G_CONSTANT i64 3
+    %3:_(s64) = G_LSHR %1, %2
+    %4:_(s32) = G_TRUNC %3(s64)
+    $x0 = COPY %3(s64)
+    $w1 = COPY %4(s32)
+    RET_ReallyLR implicit $x0, implicit $w1
+...
+
+# Option 1: extended trunc_shift drives demanded-bits AND elimination.
+# (trunc:s32 (shl:s64 (and:s64 X, low-32-mask), K))
+# After trunc_shift narrows shift to s32 (NewShiftTy == DstTy), the driver
+# is invoked on the inner AND with demand mask = low 32 bits of s64. The
+# mask 0xFFFFFFFF covers the demand entirely, so the AND is dropped before
+# the trunc is built -- the resulting MIR has no surviving G_AND.
+---
+name:            option1_driver_drops_redundant_and
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $x0
+    ; CHECK-LABEL: name: option1_driver_drops_redundant_and
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[TRUNC]], [[C]](s64)
+    ; CHECK-NEXT: $w0 = COPY [[SHL]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = G_CONSTANT i64 4294967295
+    %2:_(s64) = G_AND %0, %1
+    %3:_(s64) = G_CONSTANT i64 5
+    %4:_(s64) = G_SHL %2, %3
+    %5:_(s32) = G_TRUNC %4(s64)
+    $w0 = COPY %5(s32)
+    RET_ReallyLR implicit $w0
+...
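The test comments above state the conditions under which the narrowing is allowed: the high source bits must be known zero for LSHR, the source must have at least SrcBW - DstBW + 1 sign bits for ASHR, and the shift amount must stay below DstBW. These can be checked on plain integers. What follows is a minimal sketch, assuming a 64-bit source truncated to 32 bits. The helper names (truncOfLshr, canNarrowLshr, canNarrowAshr and so on) are hypothetical and are not LLVM APIs. The code also assumes C++20 semantics for signed conversions and arithmetic right shift.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers (not LLVM APIs). Source width 64, destination 32.
// Assumes C++20: signed conversions wrap and >> on a negative value is an
// arithmetic shift.

// (trunc:s32 (lshr:s64 X, K))
uint32_t truncOfLshr(uint64_t X, unsigned K) {
  return static_cast<uint32_t>(X >> K);
}

// (lshr:s32 (trunc:s32 X), K); a 32-bit shift is only defined for K < 32.
uint32_t lshrOfTrunc(uint64_t X, unsigned K) {
  assert(K < 32 && "narrow shift amount out of range");
  return static_cast<uint32_t>(X) >> K;
}

// (trunc:s32 (ashr:s64 X, K))
uint32_t truncOfAshr(uint64_t X, unsigned K) {
  return static_cast<uint32_t>(static_cast<int64_t>(X) >> K);
}

// (ashr:s32 (trunc:s32 X), K)
uint32_t ashrOfTrunc(uint64_t X, unsigned K) {
  assert(K < 32 && "narrow shift amount out of range");
  int32_t Lo = static_cast<int32_t>(static_cast<uint32_t>(X));
  return static_cast<uint32_t>(Lo >> K);
}

// LSHR: the wide form pulls X bits [32, 32 + K) into the truncated result,
// while the narrow form shifts in zeros there. The rewrite is sound when
// K < 32 and those bits are known zero.
bool canNarrowLshr(uint64_t KnownZero, unsigned K) {
  if (K >= 32)
    return false;
  uint64_t Pulled = ((uint64_t{1} << K) - 1) << 32;
  return (KnownZero & Pulled) == Pulled;
}

// ASHR: sound when K < 32 and X has at least 64 - 32 + 1 = 33 sign bits,
// so X bits [31, 64) all equal the narrow sign bit.
bool canNarrowAshr(unsigned NumSignBits, unsigned K) {
  return K < 32 && NumSignBits >= 33;
}
```

The LSHR check only requires bits [32, 32 + K) to be zero, because no other high bits reach the truncated result. The all-high-bits-zero mask in narrow_trunc_lshr_const_lowmask is therefore more than enough. In negative_unknown_high_bits nothing is known about those bits. For example, X = 0x100000000 shifted by 5 gives different results in the two forms.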
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
new file mode 100644
index 0000000000000..2a31a934808c7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
@@ -0,0 +1,54 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
+# RUN: llc -mtriple=aarch64 -run-pass=aarch64-postlegalizer-combiner -verify-machineinstrs %s -o - | FileCheck %s
+---
+name:            drop_redundant_outer_mask
+legalized:       true
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: drop_redundant_outer_mask
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: %lowmask:_(s32) = G_CONSTANT i32 15
+    ; CHECK-NEXT: %and2:_(s32) = G_AND [[COPY]], %lowmask
+    ; CHECK-NEXT: $w0 = COPY %and2(s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %mask:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %0, %mask
+    %lowmask:_(s32) = G_CONSTANT i32 15
+    %and2:_(s32) = G_AND %and, %lowmask
+    $w0 = COPY %and2(s32)
+    RET_ReallyLR implicit $w0
+...
+---
+name:            keep_and_multiuse_high_bits
+legalized:       true
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $x0
+    ; CHECK-LABEL: name: keep_and_multiuse_high_bits
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s64) = G_CONSTANT i64 4294967295
+    ; CHECK-NEXT: %and:_(s64) = G_AND [[COPY]], %c
+    ; CHECK-NEXT: %lo:_(s32) = G_TRUNC %and(s64)
+    ; CHECK-NEXT: %amt:_(s64) = G_CONSTANT i64 40
+    ; CHECK-NEXT: %hi:_(s64) = G_LSHR %and, %amt(s64)
+    ; CHECK-NEXT: $w0 = COPY %lo(s32)
+    ; CHECK-NEXT: $x1 = COPY %hi(s64)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0, implicit $x1
+    %0:_(s64) = COPY $x0
+    %c:_(s64) = G_CONSTANT i64 4294967295
+    %and:_(s64) = G_AND %0, %c
+    %lo:_(s32) = G_TRUNC %and(s64)
+    %amt:_(s64) = G_CONSTANT i64 40
+    %hi:_(s64) = G_LSHR %and, %amt(s64)
+    $w0 = COPY %lo(s32)
+    $x1 = COPY %hi(s64)
+    RET_ReallyLR implicit $w0, implicit $x1
+...
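Both tests above turn on one decision. An (and X, C) can be replaced by X when every bit its users demand lies inside C, and a register's demand is the union of its users' demands. What follows is a minimal sketch of that decision on s64 values. The helper names are hypothetical and only stand in for getDemandedBitsSimplifiedReg and computeDemandedBits; they do not reproduce them. The sketch also ignores known bits of X, which a real implementation could use as well.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helpers, not CombinerHelper or GISelValueTracking APIs.

// Mask with the low N bits set (N in [0, 64]).
uint64_t lowBits(unsigned N) {
  assert(N <= 64 && "width out of range");
  return N == 64 ? ~uint64_t{0} : (uint64_t{1} << N) - 1;
}

// Bits of X read by (and X, C), given the bits demanded of the result.
uint64_t demandThroughAndConst(uint64_t ResultDemand, uint64_t C) {
  return ResultDemand & C;
}

// Bits of X read by (trunc X) to NarrowBW bits.
uint64_t demandThroughTrunc(uint64_t NarrowDemand, unsigned NarrowBW) {
  return NarrowDemand & lowBits(NarrowBW);
}

// Bits of X read by (lshr X, K): result bit i comes from X bit i + K.
// An out-of-range amount falls back to demanding everything.
uint64_t demandThroughLshr(uint64_t ResultDemand, unsigned K) {
  return K >= 64 ? ~uint64_t{0} : ResultDemand << K;
}

// A register's demand is the union of what each of its users reads.
uint64_t unionOfUserDemands(const std::vector<uint64_t> &PerUse) {
  uint64_t All = 0;
  for (uint64_t D : PerUse)
    All |= D;
  return All;
}

// (and X, C) is replaceable by X when C passes every demanded bit through.
bool isAndRedundant(uint64_t C, uint64_t Demanded) {
  return (Demanded & ~C) == 0;
}
```

In keep_and_multiuse_high_bits, the G_TRUNC user reads bits [0, 32) and the G_LSHR by 40 reads bits [40, 64), so the union is 0xFFFFFF00FFFFFFFF. That union falls partly outside 0xFFFFFFFF, so the AND stays. In drop_redundant_outer_mask, the only user of the 255-mask AND is an AND with 15, which reads just 0xF. That lies inside 255, so the wider mask can be dropped.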
diff --git a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
index 9034a39b0ac51..2913bde375dc3 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
@@ -1751,13 +1751,13 @@ define <16 x i8> @andmask2v16i8(<16 x i16> %src1, <16 x i16> %src2) {
 ; CHECK-GI-NEXT:    movi.8h v4, #7
 ; CHECK-GI-NEXT:    movi.8h v5, #3
 ; CHECK-GI-NEXT:    and.16b v0, v0, v4
-; CHECK-GI-NEXT:    and.16b v2, v2, v5
 ; CHECK-GI-NEXT:    and.16b v1, v1, v4
+; CHECK-GI-NEXT:    and.16b v2, v2, v5
 ; CHECK-GI-NEXT:    and.16b v3, v3, v5
 ; CHECK-GI-NEXT:    add.8h v0, v0, v2
 ; CHECK-GI-NEXT:    add.8h v1, v1, v3
-; CHECK-GI-NEXT:    shrn.8b v0, v0, #1
-; CHECK-GI-NEXT:    shrn2.16b v0, v1, #1
+; CHECK-GI-NEXT:    uzp1.16b v0, v0, v1
+; CHECK-GI-NEXT:    ushr.16b v0, v0, #1
 ; CHECK-GI-NEXT:    ret
   %zextsrc1 = and <16 x i16> %src1, <i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7>
   %zextsrc2 = and <16 x i16> %src2, <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
diff --git a/llvm/test/CodeGen/AArch64/hadd-combine-scalar.ll b/llvm/test/CodeGen/AArch64/hadd-combine-scalar.ll
index 2d54bb737ce9a..e8e2bdb5b42b9 100644
--- a/llvm/test/CodeGen/AArch64/hadd-combine-scalar.ll
+++ b/llvm/test/CodeGen/AArch64/hadd-combine-scalar.ll
@@ -57,12 +57,17 @@ define i32 @haddu_const_lhs(i32 %src1) {
 }
 
 define i32 @haddu_const_zero(i32 %src1) {
-; CHECK-LABEL: haddu_const_zero:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, w0
-; CHECK-NEXT:    lsr x0, x8, #1
-; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: haddu_const_zero:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    mov w8, w0
+; CHECK-SD-NEXT:    lsr x0, x8, #1
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: haddu_const_zero:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    lsr w0, w0, #1
+; CHECK-GI-NEXT:    ret
   %zextsrc1 = zext i32 %src1 to i64
   %add = add i64 0, %zextsrc1
   %resulti32 = lshr i64 %add, 1
@@ -179,9 +184,7 @@ define i32 @hadds_const_zero(i32 %src1) {
 ;
 ; CHECK-GI-LABEL: hadds_const_zero:
 ; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    // kill: def $w0 killed $w0 def $x0
-; CHECK-GI-NEXT:    sbfx x0, x0, #1, #31
-; CHECK-GI-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-GI-NEXT:    asr w0, w0, #1
 ; CHECK-GI-NEXT:    ret
   %zextsrc1 = sext i32 %src1 to i64
   %add = add i64 0, %zextsrc1
diff --git a/llvm/test/CodeGen/AArch64/hadd-combine.ll b/llvm/test/CodeGen/AArch64/hadd-combine.ll
index 450069cd27428..b99437b149651 100644
--- a/llvm/test/CodeGen/AArch64/hadd-combine.ll
+++ b/llvm/test/CodeGen/AArch64/hadd-combine.ll
@@ -72,9 +72,9 @@ define <8 x i16> @haddu_const_zero(<8 x i16> %src1) {
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-GI-NEXT:    uaddw v2.4s, v1.4s, v0.4h
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    shrn v0.4h, v2.4s, #1
-; CHECK-GI-NEXT:    shrn2 v0.8h, v1.4s, #1
+; CHECK-GI-NEXT:    uaddw2 v0.4s, v1.4s, v0.8h
+; CHECK-GI-NEXT:    uzp1 v0.8h, v2.8h, v0.8h
+; CHECK-GI-NEXT:    ushr v0.8h, v0.8h, #1
 ; CHECK-GI-NEXT:    ret
   %zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
   %add = add <8 x i32> <i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0>, %zextsrc1
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/sext_inreg.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/sext_inreg.ll
index ab7e11a78ed57..659193288ff6f 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/sext_inreg.ll
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/sext_inreg.ll
@@ -1329,11 +1329,10 @@ define i65 @v_sext_inreg_i65_22(i65 %value) {
 ; GFX6-NEXT:    v_lshrrev_b32_e32 v3, 10, v1
 ; GFX6-NEXT:    v_or_b32_e32 v2, v2, v3
 ; GFX6-NEXT:    v_bfe_i32 v2, v2, 0, 1
-; GFX6-NEXT:    v_ashrrev_i32_e32 v3, 31, v2
 ; GFX6-NEXT:    v_bfe_u32 v1, v1, 0, 10
-; GFX6-NEXT:    v_lshlrev_b32_e32 v4, 10, v2
-; GFX6-NEXT:    v_ashr_i64 v[2:3], v[2:3], 22
-; GFX6-NEXT:    v_or_b32_e32 v1, v1, v4
+; GFX6-NEXT:    v_lshlrev_b32_e32 v3, 10, v2
+; GFX6-NEXT:    v_or_b32_e32 v1, v1, v3
+; GFX6-NEXT:    v_ashrrev_i32_e32 v2, 22, v2
 ; GFX6-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX8-LABEL: v_sext_inreg_i65_22:
@@ -1343,11 +1342,10 @@ define i65 @v_sext_inreg_i65_22(i65 %value) {
 ; GFX8-NEXT:    v_lshrrev_b32_e32 v3, 10, v1
 ; GFX8-NEXT:    v_or_b32_e32 v2, v2, v3
 ; GFX8-NEXT:    v_bfe_i32 v2, v2, 0, 1
-; GFX8-NEXT:    v_ashrrev_i32_e32 v3, 31, v2
 ; GFX8-NEXT:    v_bfe_u32 v1, v1, 0, 10
-; GFX8-NEXT:    v_lshlrev_b32_e32 v4, 10, v2
-; GFX8-NEXT:    v_ashrrev_i64 v[2:3], 22, v[2:3]
-; GFX8-NEXT:    v_or_b32_e32 v1, v1, v4
+; GFX8-NEXT:    v_lshlrev_b32_e32 v3, 10, v2
+; GFX8-NEXT:    v_or_b32_e32 v1, v1, v3
+; GFX8-NEXT:    v_ashrrev_i32_e32 v2, 22, v2
 ; GFX8-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX9-LABEL: v_sext_inreg_i65_22:
@@ -1357,10 +1355,9 @@ define i65 @v_sext_inreg_i65_22(i65 %value) {
 ; GFX9-NEXT:    v_lshrrev_b32_e32 v3, 10, v1
 ; GFX9-NEXT:    v_or_b32_e32 v2, v2, v3
 ; GFX9-NEXT:    v_bfe_i32 v2, v2, 0, 1
-; GFX9-NEXT:    v_ashrrev_i32_e32 v3, 31, v2
 ; GFX9-NEXT:    v_bfe_u32 v1, v1, 0, 10
 ; GFX9-NEXT:    v_lshl_or_b32 v1, v2, 10, v1
-; GFX9-NEXT:    v_ashrrev_i64 v[2:3], 22, v[2:3]
+; GFX9-NEXT:    v_ashrrev_i32_e32 v2, 22, v2
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX10PLUS-LABEL: v_sext_inreg_i65_22:
@@ -1371,9 +1368,8 @@ define i65 @v_sext_inreg_i65_22(i65 %value) {
 ; GFX10PLUS-NEXT:    v_bfe_u32 v1, v1, 0, 10
 ; GFX10PLUS-NEXT:    v_or_b32_e32 v2, v2, v3
 ; GFX10PLUS-NEXT:    v_bfe_i32 v2, v2, 0, 1
-; GFX10PLUS-NEXT:    v_ashrrev_i32_e32 v3, 31, v2
 ; GFX10PLUS-NEXT:    v_lshl_or_b32 v1, v2, 10, v1
-; GFX10PLUS-NEXT:    v_ashrrev_i64 v[2:3], 22, v[2:3]
+; GFX10PLUS-NEXT:    v_ashrrev_i32_e32 v2, 22, v2
 ; GFX10PLUS-NEXT:    s_setpc_b64 s[30:31]
   %shl = shl i65 %value, 22
   %ashr = ashr i65 %shl, 22
@@ -1444,7 +1440,7 @@ define amdgpu_ps i65 @s_sext_inreg_i65_18(i65 inreg %value) {
 ; GCN-NEXT:    s_lshl_b32 s5, s2, 14
 ; GCN-NEXT:    s_mov_b32 s4, 0
 ; GCN-NEXT:    s_or_b64 s[0:1], s[0:1], s[4:5]
-; GCN-NEXT:    s_ashr_i64 s[2:3], s[2:3], 18
+; GCN-NEXT:    s_ashr_i32 s2, s2, 18
 ; GCN-NEXT:    ; return to shader part epilog
 ;
 ; GFX10PLUS-LABEL: s_sext_inreg_i65_18:
@@ -1456,7 +1452,7 @@ define amdgpu_ps i65 @s_sext_inreg_i65_18(i65 inreg %value) {
 ; GFX10PLUS-NEXT:    s_mov_b32 s4, 0
 ; GFX10PLUS-NEXT:    s_bfe_i64 s[2:3], s[2:3], 0x10000
 ; GFX10PLUS-NEXT:    s_lshl_b32 s5, s2, 14
-; GFX10PLUS-NEXT:    s_ashr_i64 s[2:3], s[2:3], 18
+; GFX10PLUS-NEXT:    s_ashr_i32 s2, s2, 18
 ; GFX10PLUS-NEXT:    s_or_b64 s[0:1], s[0:1], s[4:5]
 ; GFX10PLUS-NEXT:    ; return to shader part epilog
   %shl = shl i65 %value, 18
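The unit tests below pin down two things. The first is per-opcode reverse transfer: given the bits demanded of a result, which bits of an operand does that use read? The second is shift demand narrowing over a range of feasible shift amounts. What follows is a sketch of these rules on 32-bit values, written the way SelectionDAG's SimplifyDemandedBits conventionally defines them. The function names are hypothetical. The exact rules inside getDemandedBitsForUse are an assumption, except where the tests confirm them; for example, G_OR with 0x0F under full demand gives 0xFFFFFFF0.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical reverse-transfer rules on 32-bit values (not the
// GISelValueTracking API). Each maps the bits demanded of a result to the
// bits demanded of one operand; any rule may over-demand, never under-demand.

uint32_t lowMask(unsigned N) { return N >= 32 ? ~0u : (1u << N) - 1; }

// (or X, C): result bits where C is 1 are 1 whatever X holds.
uint32_t orConstOperand(uint32_t Demand, uint32_t C) { return Demand & ~C; }

// (xor X, Y): each demanded result bit reads the same bit of X.
uint32_t xorOperand(uint32_t Demand) { return Demand; }

// (zext:s32 X:sN): only the low N result bits come from X.
uint32_t zextOperand(uint32_t Demand, unsigned N) {
  return Demand & lowMask(N);
}

// (sext:s32 X:sN): as zext, but any demanded extension bit reads X's sign bit.
uint32_t sextOperand(uint32_t Demand, unsigned N) {
  assert(N >= 1 && N <= 32 && "source width out of range");
  uint32_t D = Demand & lowMask(N);
  if (Demand & ~lowMask(N))
    D |= 1u << (N - 1);
  return D;
}

// Shift amount known only to lie in [MinK, MaxK]: union the demand over
// every feasible amount, and fall back to all ones if the range may reach
// the bit width.
uint32_t shlOperand(uint32_t Demand, unsigned MinK, unsigned MaxK) {
  if (MaxK >= 32)
    return ~0u;
  uint32_t D = 0;
  for (unsigned K = MinK; K <= MaxK; ++K)
    D |= Demand >> K; // result bit i reads X bit i - K
  return D;
}

uint32_t lshrOperand(uint32_t Demand, unsigned MinK, unsigned MaxK) {
  if (MaxK >= 32)
    return ~0u;
  uint32_t D = 0;
  for (unsigned K = MinK; K <= MaxK; ++K)
    D |= Demand << K; // result bit i reads X bit i + K
  return D;
}

uint32_t ashrOperand(uint32_t Demand, unsigned MinK, unsigned MaxK) {
  if (MaxK >= 32)
    return ~0u;
  uint32_t D = 0;
  for (unsigned K = MinK; K <= MaxK; ++K) {
    D |= Demand << K;
    // The top K result bits are copies of X's sign bit.
    if (K != 0 && (Demand >> (32 - K)) != 0)
      D |= 1u << 31;
  }
  return D;
}
```

Taking the union over all feasible amounts keeps each shift rule sound when the amount is bounded but not constant. When MinK == MaxK, it reduces to the single-amount rule.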
diff --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
index 8563d7f1f15c9..cd99da2cf56d7 100644
--- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
@@ -1023,20 +1023,23 @@ TEST_F(AArch64GISelMITest, TestNumSignBitsCmp) {
 }
 
 TEST_F(AMDGPUGISelMITest, TestNumSignBitsTrunc) {
-  StringRef MIRString =
-    "  %3:_(<4 x s32>) = G_IMPLICIT_DEF\n"
-    "  %4:_(s32) = G_IMPLICIT_DEF\n"
-    "  %5:_(s32) = G_AMDGPU_BUFFER_LOAD_UBYTE %3, %4, %4, %4, 0, 0, 0 :: (load (s8))\n"
-    "  %6:_(s32) = COPY %5\n"
+  StringRef MIRString = "  %3:_(<4 x s32>) = G_IMPLICIT_DEF\n"
+                        "  %4:_(s32) = G_IMPLICIT_DEF\n"
+                        "  %5:_(s32) = G_AMDGPU_BUFFER_LOAD_UBYTE %3, %4, %4, "
+                        "%4, 0, 0, 0 :: (load (s8))\n"
+                        "  %6:_(s32) = COPY %5\n"
 
-    "  %7:_(s32) = G_AMDGPU_BUFFER_LOAD_SBYTE %3, %4, %4, %4, 0, 0, 0 :: (load (s8))\n"
-    "  %8:_(s32) = COPY %7\n"
+                        "  %7:_(s32) = G_AMDGPU_BUFFER_LOAD_SBYTE %3, %4, %4, "
+                        "%4, 0, 0, 0 :: (load (s8))\n"
+                        "  %8:_(s32) = COPY %7\n"
 
-    "  %9:_(s32) = G_AMDGPU_BUFFER_LOAD_USHORT %3, %4, %4, %4, 0, 0, 0 :: (load (s16))\n"
-    "  %10:_(s32) = COPY %9\n"
+                        "  %9:_(s32) = G_AMDGPU_BUFFER_LOAD_USHORT %3, %4, %4, "
+                        "%4, 0, 0, 0 :: (load (s16))\n"
+                        "  %10:_(s32) = COPY %9\n"
 
-    "  %11:_(s32) = G_AMDGPU_BUFFER_LOAD_SSHORT %3, %4, %4, %4, 0, 0, 0 :: (load (s16))\n"
-    "  %12:_(s32) = COPY %11\n";
+                        "  %11:_(s32) = G_AMDGPU_BUFFER_LOAD_SSHORT %3, %4, "
+                        "%4, %4, 0, 0, 0 :: (load (s16))\n"
+                        "  %12:_(s32) = COPY %11\n";
 
   setUp(MIRString);
   if (!TM)
@@ -1057,16 +1060,16 @@ TEST_F(AMDGPUGISelMITest, TestNumSignBitsTrunc) {
 
 TEST_F(AMDGPUGISelMITest, TestTargetKnownAlign) {
   StringRef MIRString =
-    "  %5:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.dispatch.ptr)\n"
-    "  %6:_(p4) = COPY %5\n"
-    "  %7:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.queue.ptr)\n"
-    "  %8:_(p4) = COPY %7\n"
-    "  %9:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.kernarg.segment.ptr)\n"
-    "  %10:_(p4) = COPY %9\n"
-    "  %11:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.implicitarg.ptr)\n"
-    "  %12:_(p4) = COPY %11\n"
-    "  %13:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.implicit.buffer.ptr)\n"
-    "  %14:_(p4) = COPY %13\n";
+      "  %5:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.dispatch.ptr)\n"
+      "  %6:_(p4) = COPY %5\n"
+      "  %7:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.queue.ptr)\n"
+      "  %8:_(p4) = COPY %7\n"
+      "  %9:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.kernarg.segment.ptr)\n"
+      "  %10:_(p4) = COPY %9\n"
+      "  %11:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.implicitarg.ptr)\n"
+      "  %12:_(p4) = COPY %11\n"
+      "  %13:_(p4) = G_INTRINSIC intrinsic(@llvm.amdgcn.implicit.buffer.ptr)\n"
+      "  %14:_(p4) = COPY %13\n";
 
   setUp(MIRString);
   if (!TM)
@@ -1517,7 +1520,8 @@ TEST_F(AArch64GISelMITest, TestKnownBitsUnmergeValues) {
 
     uint16_t PartTestVal = static_cast<uint16_t>(TestVal >> BitOffset);
     EXPECT_EQ(PartTestVal, PartKnown.One.getZExtValue());
-    EXPECT_EQ(static_cast<uint16_t>(~PartTestVal), PartKnown.Zero.getZExtValue());
+    EXPECT_EQ(static_cast<uint16_t>(~PartTestVal),
+              PartKnown.Zero.getZExtValue());
   }
 }
 
@@ -1763,7 +1767,6 @@ TEST_F(AArch64GISelMITest, TestInvalidQueries) {
   KnownBits EqSizeRes = Info.getKnownBits(EqSizedShl);
   KnownBits BiggerSizeRes = Info.getKnownBits(BiggerSizedShl);
 
-
   // Result can be anything, but we should not crash.
   EXPECT_TRUE(EqSizeRes.One.isZero());
   EXPECT_TRUE(EqSizeRes.Zero.isAllOnes());
@@ -2119,7 +2122,8 @@ TEST_F(AMDGPUGISelMITest, TestKnownBitsAssertAlign) {
     EXPECT_EQ(64u, Res.getBitWidth());
     EXPECT_EQ(NumBits - 1, Res.Zero.countr_one());
     EXPECT_EQ(64u, Res.One.countr_zero());
-    EXPECT_EQ(Align(1ull << (NumBits - 1)), Info.computeKnownAlignment(Copies[Idx]));
+    EXPECT_EQ(Align(1ull << (NumBits - 1)),
+              Info.computeKnownAlignment(Copies[Idx]));
   };
 
   const unsigned NumSetupCopies = 5;
@@ -2150,3 +2154,478 @@ TEST_F(AArch64GISelMITest, TestKnownBitsUADDO) {
   EXPECT_EQ(0u, Res.One.getZExtValue());
   EXPECT_EQ(31u, Res.Zero.countl_one());
 }
+
+// ---------------------------------------------------------------------------
+// GISelDemandedMask shift demand-narrowing: stage 2 contracts.
+//
+// Each opcode (G_SHL / G_LSHR / G_ASHR) has three tests:
+//   * DemandEquivalence: getKnownBits(R) == getKnownBits(R, getAllBits).
+//   * DemandNonRegression: the partial-mask result is no looser than the
+//     all-ones result restricted to the demanded region.
+//   * DemandSoundness_i4: exhaustive i4 enumeration of (X, Sh, Mask)
+//     checks that the analysis never claims a zero/one bit that the concrete
+//     shift semantics contradict.
+// ---------------------------------------------------------------------------
+
+namespace {
+
+KnownBits runShiftAnalysis(MachineRegisterInfo &MRI, GISelValueTracking &Info,
+                           Register R, APInt Bits) {
+  return Info.getKnownBits(R, GISelDemandedMask::forBits(MRI, R, Bits));
+}
+
+} // namespace
+
+TEST_F(AArch64GISelMITest, KnownBitsSHLDemandEquivalence) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 3
+    %r:_(s32) = G_SHL %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  KnownBits Echoed =
+      Info.getKnownBits(R, GISelDemandedMask::getAllBits(*MRI, R));
+  EXPECT_EQ(Baseline.Zero, Echoed.Zero);
+  EXPECT_EQ(Baseline.One, Echoed.One);
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsLSHRDemandEquivalence) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 5
+    %r:_(s32) = G_LSHR %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  KnownBits Echoed =
+      Info.getKnownBits(R, GISelDemandedMask::getAllBits(*MRI, R));
+  EXPECT_EQ(Baseline.Zero, Echoed.Zero);
+  EXPECT_EQ(Baseline.One, Echoed.One);
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsASHRDemandEquivalence) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 7
+    %r:_(s32) = G_ASHR %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  KnownBits Echoed =
+      Info.getKnownBits(R, GISelDemandedMask::getAllBits(*MRI, R));
+  EXPECT_EQ(Baseline.Zero, Echoed.Zero);
+  EXPECT_EQ(Baseline.One, Echoed.One);
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsSHLDemandNonRegression) {
+  // Constant SHL by 3: result low 3 bits are zero. Partial-mask path must
+  // not regress this fact within the demanded window.
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 3
+    %r:_(s32) = G_SHL %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  for (uint32_t MaskBits : {0xFFFFFFFFu, 0xFFu, 0xF0u, 0x07u, 0x00u}) {
+    APInt Demand(32, MaskBits);
+    KnownBits Demanded = runShiftAnalysis(*MRI, Info, R, Demand);
+    EXPECT_TRUE((Baseline.Zero & Demand).isSubsetOf(Demanded.Zero & Demand))
+        << "MaskBits=" << MaskBits;
+    EXPECT_TRUE((Baseline.One & Demand).isSubsetOf(Demanded.One & Demand))
+        << "MaskBits=" << MaskBits;
+  }
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsLSHRDemandNonRegression) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 28
+    %r:_(s32) = G_LSHR %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  for (uint32_t MaskBits : {0xFFFFFFFFu, 0xFu, 0xF0u, 0x0Fu, 0x00u}) {
+    APInt Demand(32, MaskBits);
+    KnownBits Demanded = runShiftAnalysis(*MRI, Info, R, Demand);
+    EXPECT_TRUE((Baseline.Zero & Demand).isSubsetOf(Demanded.Zero & Demand))
+        << "MaskBits=" << MaskBits;
+    EXPECT_TRUE((Baseline.One & Demand).isSubsetOf(Demanded.One & Demand))
+        << "MaskBits=" << MaskBits;
+  }
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsASHRDemandNonRegression) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_IMPLICIT_DEF
+    %sh:_(s32) = G_CONSTANT i32 16
+    %r:_(s32) = G_ASHR %x, %sh
+    %copy_r:_(s32) = COPY %r
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  Register CopyR = Copies.back();
+  Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+  GISelValueTracking Info(*MF);
+
+  KnownBits Baseline = Info.getKnownBits(R);
+  for (uint32_t MaskBits :
+       {0xFFFFFFFFu, 0xFFu, 0xFF000000u, 0x80000000u, 0x01u, 0x00u}) {
+    APInt Demand(32, MaskBits);
+    KnownBits Demanded = runShiftAnalysis(*MRI, Info, R, Demand);
+    EXPECT_TRUE((Baseline.Zero & Demand).isSubsetOf(Demanded.Zero & Demand))
+        << "MaskBits=" << MaskBits;
+    EXPECT_TRUE((Baseline.One & Demand).isSubsetOf(Demanded.One & Demand))
+        << "MaskBits=" << MaskBits;
+  }
+}
+
+// Soundness oracle: exhaustive i4 enumeration. For every X and Mask in
+// [0, 16) and every shift amount Sh in [0, 4), assert that the demand-aware
+// analysis never claims a zero/one bit inside the demanded region that the
+// concrete shift semantics disagree with. Sh is materialized as a
+// G_CONSTANT, so the analysis sees a known constant amount and no G_AND
+// masking of the amount is needed.
+namespace {
+
+int signExtendI4(unsigned X) { return SignExtend64<4>(X); }
+
+std::string buildShiftMIR(const char *Opcode, unsigned XVal, unsigned ShVal) {
+  std::string Out;
+  raw_string_ostream OS(Out);
+  OS << "  %x:_(s4) = G_CONSTANT i4 " << signExtendI4(XVal) << "\n";
+  OS << "  %sh:_(s4) = G_CONSTANT i4 " << signExtendI4(ShVal) << "\n";
+  OS << "  %r:_(s4) = " << Opcode << " %x, %sh\n";
+  OS << "  %copy_r:_(s4) = COPY %r\n";
+  return Out;
+}
+
+} // namespace
+
+TEST_F(AArch64GISelMITest, KnownBitsSHLDemandSoundness_i4) {
+  for (unsigned XVal = 0; XVal < 16; ++XVal) {
+    for (unsigned ShVal = 0; ShVal < 4; ++ShVal) {
+      // Reset per-iteration state by reinvoking setUp with fresh MIR.
+      Copies.clear();
+      setUp(buildShiftMIR("G_SHL", XVal, ShVal));
+      if (!TM)
+        GTEST_SKIP();
+      Register CopyR = Copies.back();
+      Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+      GISelValueTracking Info(*MF);
+      unsigned Concrete = (XVal << ShVal) & 0xFu;
+      for (unsigned MaskBits = 0; MaskBits < 16; ++MaskBits) {
+        APInt Demand(4, MaskBits);
+        KnownBits Demanded =
+            Info.getKnownBits(R, GISelDemandedMask::forBits(*MRI, R, Demand));
+        APInt GTZeroInMask = APInt(4, ~Concrete & 0xFu) & Demand;
+        APInt GTOneInMask = APInt(4, Concrete & 0xFu) & Demand;
+        EXPECT_TRUE((Demanded.Zero & Demand).isSubsetOf(GTZeroInMask))
+            << "SHL X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+        EXPECT_TRUE((Demanded.One & Demand).isSubsetOf(GTOneInMask))
+            << "SHL X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+      }
+    }
+  }
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsLSHRDemandSoundness_i4) {
+  for (unsigned XVal = 0; XVal < 16; ++XVal) {
+    for (unsigned ShVal = 0; ShVal < 4; ++ShVal) {
+      Copies.clear();
+      setUp(buildShiftMIR("G_LSHR", XVal, ShVal));
+      if (!TM)
+        GTEST_SKIP();
+      Register CopyR = Copies.back();
+      Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+      GISelValueTracking Info(*MF);
+      unsigned Concrete = ((XVal & 0xFu) >> ShVal) & 0xFu;
+      for (unsigned MaskBits = 0; MaskBits < 16; ++MaskBits) {
+        APInt Demand(4, MaskBits);
+        KnownBits Demanded =
+            Info.getKnownBits(R, GISelDemandedMask::forBits(*MRI, R, Demand));
+        APInt GTZeroInMask = APInt(4, ~Concrete & 0xFu) & Demand;
+        APInt GTOneInMask = APInt(4, Concrete & 0xFu) & Demand;
+        EXPECT_TRUE((Demanded.Zero & Demand).isSubsetOf(GTZeroInMask))
+            << "LSHR X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+        EXPECT_TRUE((Demanded.One & Demand).isSubsetOf(GTOneInMask))
+            << "LSHR X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+      }
+    }
+  }
+}
+
+TEST_F(AArch64GISelMITest, DemandedBitsForUseTruncAnd) {
+  StringRef MIRString = R"(
+    %0:_(s64) = COPY $x0
+    %trunc:_(s32) = G_TRUNC %0
+    %c:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %trunc, %c
+    %4:_(s32) = COPY %and
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  GISelValueTracking VT(*MF);
+
+  // Final COPY's source is %and.
+  MachineInstr *FinalCopy = MRI->getVRegDef(Copies.back());
+  Register AndReg = FinalCopy->getOperand(1).getReg();
+  MachineInstr *AndMI = MRI->getVRegDef(AndReg);
+  Register TruncReg = AndMI->getOperand(1).getReg();
+  MachineInstr *TruncMI = MRI->getVRegDef(TruncReg);
+
+  // AND with C=0xFF, def fully demanded -> variable operand demanded in low 8.
+  EXPECT_EQ(
+      VT.getDemandedBitsForUse(*AndMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
+      APInt(32, 0xFF));
+  // TRUNC s64 -> s32: the s64 operand is demanded in its low 32 bits.
+  EXPECT_EQ(
+      VT.getDemandedBitsForUse(*TruncMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
+      APInt(64, 0xFFFFFFFFULL));
+}
+
+TEST_F(AArch64GISelMITest, ComputeDemandedBitsAndMask) {
+  // %0 is s64 (pre-populated by the framework as COPY $x0). Trunc to s32,
+  // then AND with 255. computeDemandedBits on the s32 trunc register should
+  // report only the low 8 bits demanded (the COPY %out demands all 32, the
+  // AND narrows that to 0xFF via the constant mask).
+  StringRef MIRString = R"(
+    %trunc:_(s32) = G_TRUNC %0
+    %c:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %trunc, %c
+    %out:_(s32) = COPY %and
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  GISelValueTracking VT(*MF);
+
+  // Walk: final COPY (%out) -> %and -> its variable operand %trunc.
+  Register AndReg = MRI->getVRegDef(Copies.back())->getOperand(1).getReg();
+  MachineInstr *AndMI = MRI->getVRegDef(AndReg);
+  Register TruncReg = AndMI->getOperand(1).getReg();
+
+  // %trunc's only user is (and %trunc, 255), so only its low 8 bits are
+  // demanded.
+  EXPECT_EQ(VT.computeDemandedBits(TruncReg, /*Depth=*/0), APInt(32, 0xFF));
+}
+
+TEST_F(AArch64GISelMITest, KnownBitsASHRDemandSoundness_i4) {
+  for (unsigned XVal = 0; XVal < 16; ++XVal) {
+    for (unsigned ShVal = 0; ShVal < 4; ++ShVal) {
+      Copies.clear();
+      setUp(buildShiftMIR("G_ASHR", XVal, ShVal));
+      if (!TM)
+        GTEST_SKIP();
+      Register CopyR = Copies.back();
+      Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+      GISelValueTracking Info(*MF);
+      unsigned Concrete =
+          static_cast<unsigned>(signExtendI4(XVal) >> ShVal) & 0xFu;
+      for (unsigned MaskBits = 0; MaskBits < 16; ++MaskBits) {
+        APInt Demand(4, MaskBits);
+        KnownBits Demanded =
+            Info.getKnownBits(R, GISelDemandedMask::forBits(*MRI, R, Demand));
+        APInt GTZeroInMask = APInt(4, ~Concrete & 0xFu) & Demand;
+        APInt GTOneInMask = APInt(4, Concrete & 0xFu) & Demand;
+        EXPECT_TRUE((Demanded.Zero & Demand).isSubsetOf(GTZeroInMask))
+            << "ASHR X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+        EXPECT_TRUE((Demanded.One & Demand).isSubsetOf(GTOneInMask))
+            << "ASHR X=" << XVal << " Sh=" << ShVal << " Mask=" << MaskBits;
+      }
+    }
+  }
+}
+
+// Part B: spot-check getDemandedBitsForUse for G_OR, G_XOR, G_ZEXT, G_SEXT.
+TEST_F(AArch64GISelMITest, DemandedBitsForUseOrXorZextSext) {
+  // MIR layout (all regs named for clarity):
+  //   %x8  = G_TRUNC %0        ; s32 narrow input
+  //   %c0f = G_CONSTANT i32 15 ; 0x0F
+  //   %or  = G_OR  %x8, %c0f  ; G_OR case: constant other-operand = 0x0F
+  //   %c1  = G_CONSTANT i32 42 ; arbitrary second XOR operand
+  //   %xr  = G_XOR %x8, %c1   ; G_XOR case
+  //   %t8  = G_TRUNC %0        ; s8 input for ext tests
+  //   %zx  = G_ZEXT %t8        ; s8->s32 zext
+  //   %sx  = G_SEXT %t8        ; s8->s32 sext
+  //   %out = COPY %sx          ; anchor to Copies.back()
+  StringRef MIRString = R"(
+    %x8:_(s32) = G_TRUNC %0
+    %c0f:_(s32) = G_CONSTANT i32 15
+    %or:_(s32) = G_OR %x8, %c0f
+    %c1:_(s32) = G_CONSTANT i32 42
+    %xr:_(s32) = G_XOR %x8, %c1
+    %t8:_(s8) = G_TRUNC %0
+    %zx:_(s32) = G_ZEXT %t8
+    %sx:_(s32) = G_SEXT %t8
+    %out:_(s32) = COPY %sx
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  GISelValueTracking VT(*MF);
+
+  // Scan MBBs to find the instruction objects by opcode.
+  MachineInstr *OrMI = nullptr;
+  MachineInstr *XorMI = nullptr;
+  MachineInstr *ZextMI = nullptr;
+  MachineInstr *SextMI = nullptr;
+  for (auto &MBB : *MF) {
+    for (auto &MI : MBB) {
+      switch (MI.getOpcode()) {
+      case TargetOpcode::G_OR:
+        OrMI = &MI;
+        break;
+      case TargetOpcode::G_XOR:
+        XorMI = &MI;
+        break;
+      case TargetOpcode::G_ZEXT:
+        ZextMI = &MI;
+        break;
+      case TargetOpcode::G_SEXT:
+        SextMI = &MI;
+        break;
+      default:
+        break;
+      }
+    }
+  }
+  ASSERT_NE(OrMI, nullptr);
+  ASSERT_NE(XorMI, nullptr);
+  ASSERT_NE(ZextMI, nullptr);
+  ASSERT_NE(SextMI, nullptr);
+
+  // G_OR %x8, 0x0F -- full s32 demand -> operand demand = ~0x0F = 0xFFFFFFF0.
+  // Bits where C=1 are always 1 regardless of X, so X is not demanded there.
+  EXPECT_EQ(VT.getDemandedBitsForUse(*OrMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
+            APInt(32, 0xFFFFFFF0U));
+
+  // G_XOR %x8, %c1 -- def demand 0x0F -> each operand demand 0x0F (bit-local).
+  EXPECT_EQ(VT.getDemandedBitsForUse(*XorMI, /*OpIdx=*/1, APInt(32, 0x0F)),
+            APInt(32, 0x0F));
+
+  // G_ZEXT s8->s32, def demand 0x1FF (9 bits set) -> operand demand trunc to
+  // s8 = 0xFF (high def bits are always 0, not from operand).
+  EXPECT_EQ(VT.getDemandedBitsForUse(*ZextMI, /*OpIdx=*/1, APInt(32, 0x1FFU)),
+            APInt(8, 0xFF));
+
+  // G_SEXT s8->s32, def demand 0x80000000 (only sign bit of s32) -> operand
+  // demand must include the sign bit of s8 (bit 7 = 0x80).
+  EXPECT_EQ(
+      VT.getDemandedBitsForUse(*SextMI, /*OpIdx=*/1, APInt(32, 0x80000000U)),
+      APInt(8, 0x80));
+}
+
+// Part C: exhaustive soundness oracle for G_OR-with-constant and G_XOR.
+// For all demand masks D and all operand values x, verifies that
+// getDemandedBitsForUse over-approximates: flipping only outside-demand bits
+// cannot change demanded result bits.
+TEST_F(AArch64GISelMITest, DemandedBitsForUseOrXorSoundnessOracle) {
+  // Width=8, C=0x3C (OR constant), Y=0x5A (XOR second operand).
+  constexpr unsigned W = 8;
+  constexpr uint8_t C = 0x3C;
+  constexpr uint8_t Y = 0x5A;
+
+  StringRef MIRString = R"(
+    %base:_(s8) = G_TRUNC %0
+    %cor:_(s8) = G_CONSTANT i8 60
+    %orr:_(s8) = G_OR %base, %cor
+    %cxor:_(s8) = G_CONSTANT i8 90
+    %xrr:_(s8) = G_XOR %base, %cxor
+    %sentinel:_(s8) = COPY %xrr
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+  GISelValueTracking VT(*MF);
+
+  MachineInstr *OrMI = nullptr;
+  MachineInstr *XorMI = nullptr;
+  for (auto &MBB : *MF)
+    for (auto &MI : MBB) {
+      if (MI.getOpcode() == TargetOpcode::G_OR)
+        OrMI = &MI;
+      else if (MI.getOpcode() == TargetOpcode::G_XOR)
+        XorMI = &MI;
+    }
+  ASSERT_NE(OrMI, nullptr);
+  ASSERT_NE(XorMI, nullptr);
+
+  for (unsigned D = 0; D < (1u << W); ++D) {
+    APInt Demand(W, D);
+
+    // --- G_OR oracle ---
+    APInt OrDemand = VT.getDemandedBitsForUse(*OrMI, /*OpIdx=*/1, Demand);
+    APInt OutsideOr = ~OrDemand & APInt::getAllOnes(W);
+    for (unsigned x = 0; x < (1u << W); ++x) {
+      uint8_t x2 = static_cast<uint8_t>(x ^ OutsideOr.getZExtValue());
+      uint8_t r1 = static_cast<uint8_t>(x | C);
+      uint8_t r2 = static_cast<uint8_t>(x2 | C);
+      uint8_t diff_in_demanded =
+          static_cast<uint8_t>((r1 ^ r2) & static_cast<uint8_t>(D));
+      EXPECT_EQ(diff_in_demanded, 0)
+          << "G_OR soundness FAILED: D=" << D << " x=" << x
+          << " x2=" << static_cast<unsigned>(x2)
+          << " r1=" << static_cast<unsigned>(r1)
+          << " r2=" << static_cast<unsigned>(r2);
+      if (diff_in_demanded != 0)
+        return; // Stop on the first failure to avoid flooding the log.
+    }
+
+    // --- G_XOR oracle ---
+    APInt XorDemand = VT.getDemandedBitsForUse(*XorMI, /*OpIdx=*/1, Demand);
+    APInt OutsideXor = ~XorDemand & APInt::getAllOnes(W);
+    for (unsigned x = 0; x < (1u << W); ++x) {
+      uint8_t x2 = static_cast<uint8_t>(x ^ OutsideXor.getZExtValue());
+      uint8_t r1 = static_cast<uint8_t>(x ^ Y);
+      uint8_t r2 = static_cast<uint8_t>(x2 ^ Y);
+      uint8_t diff_in_demanded =
+          static_cast<uint8_t>((r1 ^ r2) & static_cast<uint8_t>(D));
+      EXPECT_EQ(diff_in_demanded, 0)
+          << "G_XOR soundness FAILED: D=" << D << " x=" << x
+          << " x2=" << static_cast<unsigned>(x2)
+          << " r1=" << static_cast<unsigned>(r1)
+          << " r2=" << static_cast<unsigned>(r2);
+      if (diff_in_demanded != 0)
+        return;
+    }
+  }
+}
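The reverse transfer rules spot-checked above in Part B can be modelled on plain integers. What follows is a minimal sketch with stated assumptions: scalar widths of at most 64 bits, masks carried in `uint64_t`, and the other AND/OR operand already known to be a constant. The `demandFor*` helpers are hypothetical illustrations and are not the `GISelValueTracking::getDemandedBitsForUse` API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of per-opcode reverse demand transfer, in the spirit of
// getDemandedBitsForUse. Masks live in uint64_t; only the low bits of the
// relevant width are meaningful, and DefDemand is assumed already masked to
// the def's width. Not the in-tree API.

// Mask with the low W bits set (W in [0, 64]).
uint64_t lowBits(unsigned W) { return W >= 64 ? ~0ULL : ((1ULL << W) - 1); }

// trunc X: def bit p is operand bit p. Zero-extending the narrow demand to the
// operand width is the identity on uint64_t; high operand bits stay undemanded.
uint64_t demandForTrunc(uint64_t DefDemand) { return DefDemand; }

// X & C with constant C: X matters only where the result is demanded and C = 1.
uint64_t demandForAndConst(uint64_t DefDemand, uint64_t C) {
  return DefDemand & C;
}

// X | C with constant C: where C = 1 the result is 1 regardless of X.
uint64_t demandForOrConst(uint64_t DefDemand, uint64_t C) {
  return DefDemand & ~C;
}

// X ^ Y: bit-local, so X is demanded exactly where the result is.
uint64_t demandForXor(uint64_t DefDemand) { return DefDemand; }

// zext X from SrcW bits: def bits >= SrcW are constant zero.
uint64_t demandForZext(uint64_t DefDemand, unsigned SrcW) {
  return DefDemand & lowBits(SrcW);
}

// sext X from SrcW bits (SrcW >= 1): def bits >= SrcW all copy X's sign bit,
// so demanding any of them demands bit SrcW - 1 of X.
uint64_t demandForSext(uint64_t DefDemand, unsigned SrcW) {
  assert(SrcW >= 1 && SrcW <= 64 && "source width out of range");
  uint64_t OpDemand = DefDemand & lowBits(SrcW);
  if (DefDemand & ~lowBits(SrcW))
    OpDemand |= 1ULL << (SrcW - 1);
  return OpDemand;
}
```

Under these rules the expectations in `DemandedBitsForUseOrXorZextSext` follow directly: `demandForOrConst(0xFFFFFFFF, 0x0F)` gives `0xFFFFFFF0`, `demandForZext(0x1FF, 8)` gives `0xFF`, and `demandForSext(0x80000000, 8)` gives `0x80`.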

>From cc10954c2420cb68793f69f9b3d8454e481ae8df Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 26 May 2026 10:53:41 -0700
Subject: [PATCH 2/4] [GlobalISel] Address SimplifyDemandedBits review feedback

Follow-up to the fused demanded-bits simplifier. Resolves review findings
without changing the analysis result of any in-tree query.

Correctness / robustness:
* Shift KnownBits: recurse the shift-amount operand with an all-bits demand
  (GISelDemandedMask::getAllBits) rather than threading the result's
  bit-demand into it. The amount is always fully demanded; this is inert
  today (no leaf consumes Mask.Bits) but removes a latent miscompile trap.
* GISelDemandedMask::forElts asserts element-count width parity, mirroring
  forBits, so a (Bits, Elts) swap fails loudly at the construction boundary.
* getDemandedBitsForUse G_TRUNC guards its zext like the sibling cases.
* Fix the stale "intentionally left unwired" comment on simplify_demanded_bits
  (it is wired into the AArch64 post-legalizer combiner).
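
A small brute-force experiment shows why the result's bit-demand must not be threaded into the shift amount. The sketch below uses a 4-bit SHL whose amount is restricted to [0, 3]. Even when only result bit 0 is demanded, both amount bits decide whether that bit survives, so narrowing the amount's demand to the result's would be unsound. `amountBitsThatMatter` is an illustrative helper, not an in-tree function.

```cpp
#include <cassert>

// Hypothetical brute-force probe: for a 4-bit SHL whose amount S is known to
// lie in [0, 3] (so only amount bits 0 and 1 vary), return the set of amount
// bits whose flip changes at least one demanded result bit for some X.
unsigned amountBitsThatMatter(unsigned ResultDemand) {
  unsigned Matter = 0;
  for (unsigned X = 0; X < 16; ++X)
    for (unsigned S = 0; S < 4; ++S)
      for (unsigned B = 0; B < 2; ++B) {
        unsigned S2 = S ^ (1u << B);   // flip one amount bit; stays in [0, 3]
        unsigned R1 = (X << S) & 0xFu; // 4-bit SHL result
        unsigned R2 = (X << S2) & 0xFu;
        if ((R1 ^ R2) & ResultDemand)
          Matter |= 1u << B;
      }
  return Matter;
}
```

`amountBitsThatMatter(0x1)` returns `0x3`: with X = 1, amount 0 sets bit 0 while amounts 1 and 2 clear it, so flipping either amount bit is observable in the single demanded result bit.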

Tests:
* combine-simplify-demanded-bits.mir: isolate the rule via only-enable-rule
  and add positive/negative disjoint-G_OR coverage.
* combine-narrow-trunc-shr.mir: two runs in one file -- an isolated run
  (only-enable-rule) so the lowmask case shows the real rewrite instead of a
  UBFX, plus a full-pipeline run that doubles as a deadloop guard
  (narrow_trunc_shr_const coexisting with trunc_shift must reach a fixpoint,
  not ping-pong on the outer trunc). negative_shift_geq_dstbw uses an unknown
  source so it actually exercises the K>=DstBW guard rather than folding to 0.
* combine-trunc-shift-demanded-and.mir (new): isolates trunc_shift so the
  redundant-AND drop is attributable to applyCombineTruncOfShift's hook.
* KnownBitsTest.cpp: add a ranged-shift-amount soundness oracle that drives
  getShiftSrcDemandedBits's multi-amount union loop and ASHR sign-fill, and
  correct the constant-amount oracle comment.
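
The multi-amount union that the ranged oracle exercises can be sketched as follows. This is a hypothetical model, not the in-tree `getShiftSrcDemandedBits`: it assumes a scalar width W of at most 32 bits and a known amount range [Lo, Hi] with Hi < W, and for ASHR it adds the sign bit whenever a demanded result bit lies in the sign-filled top C positions. `isSound` is a brute-force check in the style of the unit tests; because each shift result bit copies at most one source bit, flipping all undemanded source bits at once is enough to expose an under-demand.

```cpp
#include <cassert>
#include <cstdint>

enum class ShiftOp { Shl, LShr, AShr };

// Mask with the low W bits set (1 <= W <= 32).
uint32_t lowMask(unsigned W) { return W == 32 ? 0xFFFFFFFFu : ((1u << W) - 1); }

// Hypothetical helper: source bits that can reach a demanded result bit,
// unioned over every feasible amount C in [Lo, Hi].
uint32_t shiftSrcDemand(ShiftOp Op, unsigned W, uint32_t Demand, unsigned Lo,
                        unsigned Hi) {
  assert(W >= 1 && W <= 32 && Lo <= Hi && Hi < W && "bad width or range");
  const uint32_t M = lowMask(W);
  Demand &= M;
  uint32_t Src = 0;
  for (unsigned C = Lo; C <= Hi; ++C) {
    switch (Op) {
    case ShiftOp::Shl:
      // Result bit i comes from source bit i - C; bits below C are zero-filled.
      Src |= Demand >> C;
      break;
    case ShiftOp::LShr:
      // Result bit i comes from source bit i + C; the top C bits are zero-filled.
      Src |= static_cast<uint32_t>((uint64_t(Demand) << C) & M);
      break;
    case ShiftOp::AShr:
      Src |= static_cast<uint32_t>((uint64_t(Demand) << C) & M);
      // Result bits i >= W - C are copies of the source sign bit.
      if (C > 0 && (Demand >> (W - C)) != 0)
        Src |= 1u << (W - 1);
      break;
    }
  }
  return Src;
}

// Concrete W-bit shift semantics (C < W).
uint32_t evalShift(ShiftOp Op, unsigned W, uint32_t X, unsigned C) {
  const uint32_t M = lowMask(W);
  X &= M;
  switch (Op) {
  case ShiftOp::Shl:
    return static_cast<uint32_t>((uint64_t(X) << C) & M);
  case ShiftOp::LShr:
    return X >> C;
  case ShiftOp::AShr: {
    uint32_t R = X >> C;
    if (((X >> (W - 1)) & 1u) && C > 0)
      R |= M & ~(M >> C); // fill the top C bits with ones
    return R;
  }
  }
  return 0;
}

// Brute force: flipping every undemanded source bit never changes a demanded
// result bit, for any X and any amount in [Lo, Hi]. Small widths only.
bool isSound(ShiftOp Op, unsigned W, uint32_t Demand, unsigned Lo,
             unsigned Hi) {
  assert(W <= 16 && "brute force is only practical for small widths");
  const uint32_t M = lowMask(W);
  Demand &= M;
  const uint32_t Outside = ~shiftSrcDemand(Op, W, Demand, Lo, Hi) & M;
  for (uint32_t X = 0; X <= M; ++X)
    for (unsigned C = Lo; C <= Hi; ++C)
      if ((evalShift(Op, W, X, C) ^ evalShift(Op, W, X ^ Outside, C)) & Demand)
        return false;
  return true;
}
```

For a 4-bit LSHR with only result bit 0 demanded and the amount in [0, 3], `shiftSrcDemand` returns `0xF`: some feasible amount routes each source bit into bit 0. For ASHR with only bit 3 demanded it returns `0x8`, since that bit is always the sign bit.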
---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |   8 +-
 .../include/llvm/Target/GlobalISel/Combine.td |   4 +-
 .../CodeGen/GlobalISel/GISelValueTracking.cpp |  30 +++-
 .../GlobalISel/combine-narrow-trunc-shr.mir   | 160 ++++++++++--------
 .../combine-simplify-demanded-bits.mir        |  55 +++++-
 .../combine-trunc-shift-demanded-and.mir      |  33 ++++
 .../CodeGen/GlobalISel/KnownBitsTest.cpp      |  84 ++++++++-
 7 files changed, 283 insertions(+), 91 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/combine-trunc-shift-demanded-and.mir

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index fa43650fd202b..bbbef0c45380e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1197,10 +1197,10 @@ class CombinerHelper {
   bool trySimplifyDemandedBits(MachineInstr &MI, unsigned OperandNo,
                                const GISelDemandedMask &Mask) const;
 
-  /// Per-instruction demanded-bits simplification. Computes the instruction's
-  /// demand from its users; if its def can be replaced under that demand,
-  /// captures the rewrite. Deadloop-safe: returns true only when apply will
-  /// make progress.
+  /// Match per-instruction demanded-bits simplification. Computes the
+  /// instruction's demand from its users; if its def can be replaced under that
+  /// demand, captures the apply callback. Deadloop-safe: returns true only when
+  /// the captured callback will make progress.
   bool matchSimplifyDemandedBits(MachineInstr &MI, BuildFnTy &MatchInfo) const;
 
   // Match (G_TRUNC (G_LSHR/G_ASHR X, K-const)) when X's bits beyond DstBW
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index a090eeecde14b..46595ee7295a3 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1126,8 +1126,8 @@ def narrow_trunc_shr_const : GICombineRule<
 // mask (for AND) covers every demanded bit, or the constant or-bits (for OR)
 // are entirely undemanded, the instruction is replaced by its LHS operand.
 // Deadloop-safe: match returns true only when the rewrite will make progress.
-// NOTE: intentionally left unwired from all combiner groups -- wire-up is a
-// separate step once the rule is validated end-to-end.
+// Wired into the AArch64 post-legalizer combiner; not in the generic groups so
+// other targets opt in explicitly.
 def simplify_demanded_bits : GICombineRule<
   (defs root:$root, build_fn_matchinfo:$matchinfo),
   (match (wip_match_opcode G_AND, G_OR):$root,
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
index a0047f990f288..14753f00027dd 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
@@ -76,6 +76,16 @@ GISelDemandedMask GISelDemandedMask::forBits(const MachineRegisterInfo &MRI,
 
 GISelDemandedMask GISelDemandedMask::forElts(const MachineRegisterInfo &MRI,
                                              Register R, APInt Elts) {
+  LLT Ty = MRI.getType(R);
+  // Elts must match the register's element count (1 for scalars and scalable
+  // vectors). Mirror forBits and refuse a mismatch at the construction
+  // boundary so a (Bits, Elts) swap fails loudly instead of being silently
+  // disabled downstream.
+  assert(
+      (!Ty.isValid() ||
+       Elts.getBitWidth() == (Ty.isFixedVector() ? Ty.getNumElements() : 1)) &&
+      "GISelDemandedMask::forElts: Elts width must match register's "
+      "element count");
   return getAllBits(MRI, R).withElts(std::move(Elts));
 }
 
@@ -232,7 +242,10 @@ APInt GISelValueTracking::getDemandedBitsForUse(const MachineInstr &UseMI,
   switch (UseMI.getOpcode()) {
   case TargetOpcode::G_TRUNC:
     // def (narrow) bit p == operand bit p; high operand bits not demanded.
-    return DemandOfUser.zext(OpBW);
+    // The def is narrower than the operand, so DemandOfUser should fit; guard
+    // symmetrically with the other cases against a wider-than-expected demand.
+    return DemandOfUser.getBitWidth() <= OpBW ? DemandOfUser.zext(OpBW)
+                                              : APInt::getAllOnes(OpBW);
   case TargetOpcode::G_AND: {
     // result = X & C. If the OTHER operand is a constant C, X is demanded only
     // where the result is demanded AND C is set. Otherwise demand all of D.
@@ -696,7 +709,10 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
                                             Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
     computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
-    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
+    // The shift amount is fully demanded (every amount bit selects the route);
+    // never thread the result's bit-demand into it.
+    computeKnownBitsImpl(AmtReg, RHSKnown,
+                         GISelDemandedMask::getAllBits(MRI, AmtReg), Depth + 1);
     Known = KnownBits::ashr(LHSKnown, RHSKnown);
     break;
   }
@@ -707,7 +723,10 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
                                             Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
     computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
-    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
+    // The shift amount is fully demanded (every amount bit selects the route);
+    // never thread the result's bit-demand into it.
+    computeKnownBitsImpl(AmtReg, RHSKnown,
+                         GISelDemandedMask::getAllBits(MRI, AmtReg), Depth + 1);
     Known = KnownBits::lshr(LHSKnown, RHSKnown);
     break;
   }
@@ -718,7 +737,10 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
                                             Opcode, Depth + 1);
     KnownBits LHSKnown, RHSKnown;
     computeKnownBitsImpl(SrcReg, LHSKnown, Mask.withBits(SrcBits), Depth + 1);
-    computeKnownBitsImpl(AmtReg, RHSKnown, Mask, Depth + 1);
+    // The shift amount is fully demanded (every amount bit selects the route);
+    // never thread the result's bit-demand into it.
+    computeKnownBitsImpl(AmtReg, RHSKnown,
+                         GISelDemandedMask::getAllBits(MRI, AmtReg), Depth + 1);
     Known = KnownBits::shl(LHSKnown, RHSKnown);
     break;
   }
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
index fda97e9cfc4e4..256e10d9b054c 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-narrow-trunc-shr.mir
@@ -1,8 +1,18 @@
 # NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 6
-# RUN: llc -o - -mtriple=aarch64-unknown-unknown -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs %s | FileCheck %s
+# RUN: llc -o - -mtriple=aarch64-unknown-unknown -run-pass=aarch64-prelegalizer-combiner -aarch64prelegalizercombiner-only-enable-rule=narrow_trunc_shr_const -verify-machineinstrs %s | FileCheck %s --check-prefixes=CHECK,ISO
+# RUN: llc -o - -mtriple=aarch64-unknown-unknown -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs %s | FileCheck %s --check-prefixes=CHECK,FULL
+#
+# Two runs share this file:
+#  - the isolated run enables only narrow_trunc_shr_const, so each test pins
+#    this rule (no trunc_shift / bitfield-extract / narrow_binop can confound
+#    the output) -- the ISO-prefixed checks.
+#  - the full run lets the whole prelegalizer combiner go, so
+#    narrow_trunc_shr_const coexists with trunc_shift et al -- the FULL-prefixed
+#    checks. It doubles as a deadloop guard: a reintroduced trunc-of-shift
+#    ping-pong trips the combiner iteration limit / hangs rather than diffing.
 
 # narrow_trunc_shr_const: (trunc (lshr (and X, low-32-mask), K-const)) -> (lshr (trunc X), K-const)
-# AND proves bits [32..) zero; K=5; K+DstBW=37 <= SrcBW=64.
+# AND proves bits [32..) zero; K=5; K+DstBW=37 <= SrcBW=64. Outer trunc dropped.
 ---
 name:            narrow_trunc_lshr_const_lowmask
 legalized: false
@@ -10,16 +20,28 @@ tracksRegLiveness: true
 body:             |
   bb.0:
     liveins: $x0
-    ; CHECK-LABEL: name: narrow_trunc_lshr_const_lowmask
-    ; CHECK: liveins: $x0
-    ; CHECK-NEXT: {{  $}}
-    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
-    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
-    ; CHECK-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 27
-    ; CHECK-NEXT: [[UBFX:%[0-9]+]]:_(s64) = G_UBFX [[COPY]], [[C]](s64), [[C1]]
-    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[UBFX]](s64)
-    ; CHECK-NEXT: $w0 = COPY [[TRUNC]](s32)
-    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    ; ISO-LABEL: name: narrow_trunc_lshr_const_lowmask
+    ; ISO: liveins: $x0
+    ; ISO-NEXT: {{  $}}
+    ; ISO-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; ISO-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 4294967295
+    ; ISO-NEXT: [[AND:%[0-9]+]]:_(s64) = G_AND [[COPY]], [[C]]
+    ; ISO-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[AND]](s64)
+    ; ISO-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
+    ; ISO-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[TRUNC]], [[C1]](s32)
+    ; ISO-NEXT: $w0 = COPY [[LSHR]](s32)
+    ; ISO-NEXT: RET_ReallyLR implicit $w0
+    ;
+    ; FULL-LABEL: name: narrow_trunc_lshr_const_lowmask
+    ; FULL: liveins: $x0
+    ; FULL-NEXT: {{  $}}
+    ; FULL-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; FULL-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
+    ; FULL-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 27
+    ; FULL-NEXT: [[UBFX:%[0-9]+]]:_(s64) = G_UBFX [[COPY]], [[C]](s64), [[C1]]
+    ; FULL-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[UBFX]](s64)
+    ; FULL-NEXT: $w0 = COPY [[TRUNC]](s32)
+    ; FULL-NEXT: RET_ReallyLR implicit $w0
     %0:_(s64) = COPY $x0
     %1:_(s64) = G_CONSTANT i64 4294967295
     %2:_(s64) = G_AND %0, %1
@@ -38,14 +60,25 @@ tracksRegLiveness: true
 body:             |
   bb.0:
     liveins: $w0
-    ; CHECK-LABEL: name: narrow_trunc_lshr_const_zext
-    ; CHECK: liveins: $w0
-    ; CHECK-NEXT: {{  $}}
-    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
-    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 7
-    ; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[COPY]], [[C]](s32)
-    ; CHECK-NEXT: $w0 = COPY [[LSHR]](s32)
-    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    ; ISO-LABEL: name: narrow_trunc_lshr_const_zext
+    ; ISO: liveins: $w0
+    ; ISO-NEXT: {{  $}}
+    ; ISO-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; ISO-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[COPY]](s32)
+    ; ISO-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[ZEXT]](s64)
+    ; ISO-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 7
+    ; ISO-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[TRUNC]], [[C]](s32)
+    ; ISO-NEXT: $w0 = COPY [[LSHR]](s32)
+    ; ISO-NEXT: RET_ReallyLR implicit $w0
+    ;
+    ; FULL-LABEL: name: narrow_trunc_lshr_const_zext
+    ; FULL: liveins: $w0
+    ; FULL-NEXT: {{  $}}
+    ; FULL-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; FULL-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 7
+    ; FULL-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[COPY]], [[C]](s32)
+    ; FULL-NEXT: $w0 = COPY [[LSHR]](s32)
+    ; FULL-NEXT: RET_ReallyLR implicit $w0
     %0:_(s32) = COPY $w0
     %1:_(s64) = G_ZEXT %0(s32)
     %2:_(s64) = G_CONSTANT i64 7
@@ -63,14 +96,25 @@ tracksRegLiveness: true
 body:             |
   bb.0:
     liveins: $w0
-    ; CHECK-LABEL: name: narrow_trunc_ashr_const_sext
-    ; CHECK: liveins: $w0
-    ; CHECK-NEXT: {{  $}}
-    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
-    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 11
-    ; CHECK-NEXT: [[ASHR:%[0-9]+]]:_(s32) = G_ASHR [[COPY]], [[C]](s32)
-    ; CHECK-NEXT: $w0 = COPY [[ASHR]](s32)
-    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    ; ISO-LABEL: name: narrow_trunc_ashr_const_sext
+    ; ISO: liveins: $w0
+    ; ISO-NEXT: {{  $}}
+    ; ISO-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; ISO-NEXT: [[SEXT:%[0-9]+]]:_(s64) = G_SEXT [[COPY]](s32)
+    ; ISO-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[SEXT]](s64)
+    ; ISO-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 11
+    ; ISO-NEXT: [[ASHR:%[0-9]+]]:_(s32) = G_ASHR [[TRUNC]], [[C]](s32)
+    ; ISO-NEXT: $w0 = COPY [[ASHR]](s32)
+    ; ISO-NEXT: RET_ReallyLR implicit $w0
+    ;
+    ; FULL-LABEL: name: narrow_trunc_ashr_const_sext
+    ; FULL: liveins: $w0
+    ; FULL-NEXT: {{  $}}
+    ; FULL-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; FULL-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 11
+    ; FULL-NEXT: [[ASHR:%[0-9]+]]:_(s32) = G_ASHR [[COPY]], [[C]](s32)
+    ; FULL-NEXT: $w0 = COPY [[ASHR]](s32)
+    ; FULL-NEXT: RET_ReallyLR implicit $w0
     %0:_(s32) = COPY $w0
     %1:_(s64) = G_SEXT %0(s32)
     %2:_(s64) = G_CONSTANT i64 11
@@ -81,8 +125,6 @@ body:             |
 ...
 
 # Negative: unknown high bits. Must NOT strip the outer trunc to lshr(trunc).
-# (trunc_shift may still introduce its own narrowing via the mid-VT path; ensure
-# the narrow_trunc_shr_const rewrite does not fire on its own.)
 ---
 name:            negative_unknown_high_bits
 legalized: false
@@ -107,27 +149,29 @@ body:             |
     RET_ReallyLR implicit $w0
 ...
 
-# Negative: K >= DstBW. Combine must not fire (the rewrite would be unsound:
-# the K+DstBW high X bits would be sourced from a position past SrcBW.)
+# Negative: K >= DstBW. Source is an unknown s64 (no constant fold can hide the
+# decision), so the unchanged lshr+trunc proves the K>=DstBW guard fired.
 ---
 name:            negative_shift_geq_dstbw
 legalized: false
 tracksRegLiveness: true
 body:             |
   bb.0:
-    liveins: $w0
+    liveins: $x0
     ; CHECK-LABEL: name: negative_shift_geq_dstbw
-    ; CHECK: liveins: $w0
+    ; CHECK: liveins: $x0
     ; CHECK-NEXT: {{  $}}
-    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
-    ; CHECK-NEXT: $w0 = COPY [[C]](s32)
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 32
+    ; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[COPY]], [[C]](s64)
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[LSHR]](s64)
+    ; CHECK-NEXT: $w0 = COPY [[TRUNC]](s32)
     ; CHECK-NEXT: RET_ReallyLR implicit $w0
-    %0:_(s32) = COPY $w0
-    %1:_(s64) = G_ZEXT %0(s32)
-    %2:_(s64) = G_CONSTANT i64 32
-    %3:_(s64) = G_LSHR %1, %2
-    %4:_(s32) = G_TRUNC %3(s64)
-    $w0 = COPY %4(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = G_CONSTANT i64 32
+    %2:_(s64) = G_LSHR %0, %1
+    %3:_(s32) = G_TRUNC %2(s64)
+    $w0 = COPY %3(s32)
     RET_ReallyLR implicit $w0
 ...
 
@@ -159,35 +203,3 @@ body:             |
     $w1 = COPY %4(s32)
     RET_ReallyLR implicit $x0, implicit $w1
 ...
-
-# Option 1: extended trunc_shift drives demanded-bits AND elimination.
-# (trunc:s32 (shl:s64 (and:s64 X, low-32-mask), K))
-# After trunc_shift narrows shift to s32 (NewShiftTy == DstTy), the driver
-# is invoked on the inner AND with demand mask = low 32 bits of s64. The
-# mask 0xFFFFFFFF covers the demand entirely, so the AND is dropped before
-# the trunc is built -- the resulting MIR has no surviving G_AND.
----
-name:            option1_driver_drops_redundant_and
-legalized: false
-tracksRegLiveness: true
-body:             |
-  bb.0:
-    liveins: $x0
-    ; CHECK-LABEL: name: option1_driver_drops_redundant_and
-    ; CHECK: liveins: $x0
-    ; CHECK-NEXT: {{  $}}
-    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
-    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
-    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[COPY]](s64)
-    ; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[TRUNC]], [[C]](s64)
-    ; CHECK-NEXT: $w0 = COPY [[SHL]](s32)
-    ; CHECK-NEXT: RET_ReallyLR implicit $w0
-    %0:_(s64) = COPY $x0
-    %1:_(s64) = G_CONSTANT i64 4294967295
-    %2:_(s64) = G_AND %0, %1
-    %3:_(s64) = G_CONSTANT i64 5
-    %4:_(s64) = G_SHL %2, %3
-    %5:_(s32) = G_TRUNC %4(s64)
-    $w0 = COPY %5(s32)
-    RET_ReallyLR implicit $w0
-...
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
index 2a31a934808c7..90fdddd1ca004 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-simplify-demanded-bits.mir
@@ -1,5 +1,5 @@
 # NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
-# RUN: llc -mtriple=aarch64 -run-pass=aarch64-postlegalizer-combiner -verify-machineinstrs %s -o - | FileCheck %s
+# RUN: llc -mtriple=aarch64 -run-pass=aarch64-postlegalizer-combiner -aarch64postlegalizercombiner-only-enable-rule=simplify_demanded_bits -verify-machineinstrs %s -o - | FileCheck %s
 ---
 name:            drop_redundant_outer_mask
 legalized:       true
@@ -52,3 +52,56 @@ body:             |
     $x1 = COPY %hi(s64)
     RET_ReallyLR implicit $w0, implicit $x1
 ...
+---
+# (or X, C) feeding a use that only demands bits where C is zero: the constant
+# or-bits are entirely undemanded, so the disjoint G_OR is dropped and the use
+# reads X directly.
+name:            drop_disjoint_or_high_const
+legalized:       true
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: drop_disjoint_or_high_const
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: %lowmask:_(s32) = G_CONSTANT i32 255
+    ; CHECK-NEXT: %and:_(s32) = G_AND [[COPY]], %lowmask
+    ; CHECK-NEXT: $w0 = COPY %and(s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %hc:_(s32) = G_CONSTANT i32 65280
+    %or:_(s32) = G_OR %0, %hc
+    %lowmask:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %or, %lowmask
+    $w0 = COPY %and(s32)
+    RET_ReallyLR implicit $w0
+...
+---
+# Negative: the demanded window overlaps the OR constant's set bits, so the OR
+# is observable and must be kept.
+name:            keep_or_demanded_const
+legalized:       true
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $w0
+    ; CHECK-LABEL: name: keep_or_demanded_const
+    ; CHECK: liveins: $w0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: %hc:_(s32) = G_CONSTANT i32 255
+    ; CHECK-NEXT: %or:_(s32) = G_OR [[COPY]], %hc
+    ; CHECK-NEXT: %lowmask:_(s32) = G_CONSTANT i32 4095
+    ; CHECK-NEXT: %and:_(s32) = G_AND %or, %lowmask
+    ; CHECK-NEXT: $w0 = COPY %and(s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %hc:_(s32) = G_CONSTANT i32 255
+    %or:_(s32) = G_OR %0, %hc
+    %lowmask:_(s32) = G_CONSTANT i32 4095
+    %and:_(s32) = G_AND %or, %lowmask
+    $w0 = COPY %and(s32)
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-trunc-shift-demanded-and.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-trunc-shift-demanded-and.mir
new file mode 100644
index 0000000000000..6ce75897d0e77
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-trunc-shift-demanded-and.mir
@@ -0,0 +1,33 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 6
+# RUN: llc -o - -mtriple=aarch64-unknown-unknown -run-pass=aarch64-prelegalizer-combiner -aarch64prelegalizercombiner-only-enable-rule=trunc_shift -verify-machineinstrs %s | FileCheck %s
+#
+# Isolates trunc_shift: only this rule is enabled, so narrow_binop / redundant-and
+# folds cannot drop the inner G_AND. applyCombineTruncOfShift's stage-3 demanded-
+# bits hook is the ONLY thing that can eliminate it. With the hook, the inner AND
+# (mask = low 32 bits, which the post-trunc demand covers entirely) is dropped
+# before the trunc is built; without it the AND would survive as (and (trunc X), -1).
+---
+name:            hook_drops_redundant_and_under_trunc_shl
+legalized: false
+tracksRegLiveness: true
+body:             |
+  bb.0:
+    liveins: $x0
+    ; CHECK-LABEL: name: hook_drops_redundant_and_under_trunc_shl
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 5
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[TRUNC]], [[C]](s64)
+    ; CHECK-NEXT: $w0 = COPY [[SHL]](s32)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = G_CONSTANT i64 4294967295
+    %2:_(s64) = G_AND %0, %1
+    %3:_(s64) = G_CONSTANT i64 5
+    %4:_(s64) = G_SHL %2, %3
+    %5:_(s32) = G_TRUNC %4(s64)
+    $w0 = COPY %5(s32)
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
index cd99da2cf56d7..b93058246e8e6 100644
--- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
@@ -2317,12 +2317,14 @@ TEST_F(AArch64GISelMITest, KnownBitsASHRDemandNonRegression) {
   }
 }
 
-// Soundness oracle: exhaustive iN=4 enumeration. For every (X, Sh, Mask)
-// in [0..16), assert that the demand-aware analysis never claims a
-// zero/one bit inside the demanded region that the concrete shift
-// semantics disagree with. The shift amount is a G_CONSTANT in [0, 3]
-// so the analysis sees a known constant amount (no G_AND masking is
-// required; the helper iterates ShVal over [0, 3] directly).
+// Soundness oracle: exhaustive iN=4 enumeration with a CONSTANT shift amount.
+// For every (X, Sh, Mask) in [0..16), assert the demand-aware analysis never
+// claims a zero/one bit inside the demanded region that the concrete shift
+// semantics contradict. With a constant amount the shift-source demand reduces
+// to a singleton range, and because no leaf opcode consumes Mask.Bits today the
+// demand mask does not change the KnownBits result here -- so this verifies
+// soundness of the constant-amount path. The RangedAmount oracle further down
+// drives getShiftSrcDemandedBits's multi-amount union loop and ASHR sign-fill.
 namespace {
 
 int signExtendI4(unsigned X) { return SignExtend64<4>(X); }
@@ -2337,6 +2339,23 @@ std::string buildShiftMIR(const char *Opcode, unsigned XVal, unsigned ShVal) {
   return Out;
 }
 
+// Like buildShiftMIR but the amount is a *non-constant* value masked to [0, 3]
+// (G_AND of an undef with 0b0011). The analysis then sees a multi-element
+// amount range, so getShiftSrcDemandedBits iterates C over [0, 3] (and for
+// ASHR evaluates the sign-fill broadcast) rather than the trivial singleton
+// path.
+std::string buildRangedShiftMIR(const char *Opcode, unsigned XVal) {
+  std::string Out;
+  raw_string_ostream OS(Out);
+  OS << "  %x:_(s4) = G_CONSTANT i4 " << signExtendI4(XVal) << "\n";
+  OS << "  %rawsh:_(s4) = G_IMPLICIT_DEF\n";
+  OS << "  %m:_(s4) = G_CONSTANT i4 3\n";
+  OS << "  %sh:_(s4) = G_AND %rawsh, %m\n";
+  OS << "  %r:_(s4) = " << Opcode << " %x, %sh\n";
+  OS << "  %copy_r:_(s4) = COPY %r\n";
+  return Out;
+}
+
 } // namespace
 
 TEST_F(AArch64GISelMITest, KnownBitsSHLDemandSoundness_i4) {
@@ -2475,6 +2494,59 @@ TEST_F(AArch64GISelMITest, KnownBitsASHRDemandSoundness_i4) {
   }
 }
 
+// Ranged-amount soundness oracle: the shift amount is non-constant in [0, 3],
+// so getShiftSrcDemandedBits unions the routed source bits over C in [0, 3]
+// (and, for ASHR, evaluates the sign-fill broadcast). For each X and demand
+// mask, every zero/one bit the analysis claims inside the demanded region must
+// hold for all feasible amounts, i.e. lie within the intersection over C.
+TEST_F(AArch64GISelMITest, KnownBitsRangedShiftAmountDemandSoundness_i4) {
+  struct OpInfo {
+    const char *Name;
+    unsigned (*Eval)(unsigned X, unsigned C); // -> concrete result & 0xF
+  };
+  const OpInfo Ops[] = {
+      {"G_SHL",
+       [](unsigned X, unsigned C) -> unsigned { return (X << C) & 0xFu; }},
+      {"G_LSHR",
+       [](unsigned X, unsigned C) -> unsigned {
+         return ((X & 0xFu) >> C) & 0xFu;
+       }},
+      {"G_ASHR",
+       [](unsigned X, unsigned C) -> unsigned {
+         return static_cast<unsigned>(signExtendI4(X) >> C) & 0xFu;
+       }},
+  };
+  for (const OpInfo &Op : Ops) {
+    for (unsigned XVal = 0; XVal < 16; ++XVal) {
+      Copies.clear();
+      setUp(buildRangedShiftMIR(Op.Name, XVal));
+      if (!TM)
+        GTEST_SKIP();
+      Register CopyR = Copies.back();
+      Register R = MRI->getVRegDef(CopyR)->getOperand(1).getReg();
+      GISelValueTracking Info(*MF);
+      // Bits that are zero (resp. one) for every feasible amount c in [0, 3].
+      unsigned ZeroInAll = 0xFu, OneInAll = 0xFu;
+      for (unsigned C = 0; C < 4; ++C) {
+        unsigned Concrete = Op.Eval(XVal, C);
+        ZeroInAll &= ~Concrete & 0xFu;
+        OneInAll &= Concrete;
+      }
+      for (unsigned MaskBits = 0; MaskBits < 16; ++MaskBits) {
+        APInt Demand(4, MaskBits);
+        KnownBits Demanded =
+            Info.getKnownBits(R, GISelDemandedMask::forBits(*MRI, R, Demand));
+        EXPECT_TRUE(
+            (Demanded.Zero & Demand).isSubsetOf(APInt(4, ZeroInAll) & Demand))
+            << Op.Name << " X=" << XVal << " Mask=" << MaskBits;
+        EXPECT_TRUE(
+            (Demanded.One & Demand).isSubsetOf(APInt(4, OneInAll) & Demand))
+            << Op.Name << " X=" << XVal << " Mask=" << MaskBits;
+      }
+    }
+  }
+}
+
 // Part B: spot-check getDemandedBitsForUse for G_OR, G_XOR, G_ZEXT, G_SEXT.
 TEST_F(AArch64GISelMITest, DemandedBitsForUseOrXorZextSext) {
   // MIR layout (all regs named for clarity):

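The expected values in `KnownBitsRangedShiftAmountDemandSoundness_i4` can be cross-checked outside LLVM. The following is a minimal sketch in plain C++ with no LLVM headers. `evalShift`, `intersectOverAmounts`, `isSoundClaim` and `signExtend4` are hypothetical names, and `signExtend4` stands in for `SignExtend64<4>`. The sketch assumes C++20, where right-shifting a negative `int` is guaranteed to be arithmetic.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for SignExtend64<4>: interpret the low 4 bits as i4.
int signExtend4(unsigned X) {
  int S = static_cast<int>(X & 0xFu);
  return S >= 8 ? S - 16 : S;
}

// Concrete 4-bit result of shifting X by a known amount C.
unsigned evalShift(const std::string &Op, unsigned X, unsigned C) {
  if (Op == "G_SHL")
    return (X << C) & 0xFu;
  if (Op == "G_LSHR")
    return ((X & 0xFu) >> C) & 0xFu;
  // G_ASHR: arithmetic shift of the sign-extended value (C++20 semantics).
  return static_cast<unsigned>(signExtend4(X) >> C) & 0xFu;
}

struct AllAmountBits {
  unsigned ZeroInAll; // bits that are 0 for every amount in [0, 3]
  unsigned OneInAll;  // bits that are 1 for every amount in [0, 3]
};

// Intersect the concrete results over every feasible amount C in [0, 3].
AllAmountBits intersectOverAmounts(const std::string &Op, unsigned X) {
  AllAmountBits R{0xFu, 0xFu};
  for (unsigned C = 0; C < 4; ++C) {
    unsigned Concrete = evalShift(Op, X, C);
    R.ZeroInAll &= ~Concrete & 0xFu;
    R.OneInAll &= Concrete;
  }
  return R;
}

// Soundness predicate mirroring the test: claimed known bits, restricted to
// Demand, must be a subset of the all-amounts intersection.
bool isSoundClaim(unsigned ClaimedZero, unsigned ClaimedOne, unsigned Demand,
                  AllAmountBits Truth) {
  unsigned Z = ClaimedZero & Demand;
  unsigned O = ClaimedOne & Demand;
  return (Z & ~Truth.ZeroInAll) == 0 && (O & ~Truth.OneInAll) == 0;
}
```

For example, `G_SHL` of `X = 0b0110` leaves only bit 0 zero across all four amounts, so a claim that bit 1 is known zero is unsound when bit 1 is demanded but harmless when it is not.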
>From a7d5483763c19f1686e98d492e98efe40c7108d6 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Jun 2026 08:35:47 -0700
Subject: [PATCH 3/4] refactor(gisel): demand-drive bit simplifier

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  41 +-
 .../CodeGen/GlobalISel/GISelDemandedMask.h    |  13 +-
 .../CodeGen/GlobalISel/GISelValueTracking.h   |  14 -
 .../include/llvm/Target/GlobalISel/Combine.td |   7 +-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 221 +++++++----
 .../CodeGen/GlobalISel/GISelValueTracking.cpp |  91 -----
 .../CodeGen/GlobalISel/KnownBitsTest.cpp      | 349 ++++++++----------
 7 files changed, 338 insertions(+), 398 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index bbbef0c45380e..950d27f81c8ed 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -39,6 +39,7 @@ class MachineRegisterInfo;
 class MachineInstr;
 class MachineOperand;
 class GISelValueTracking;
+struct KnownBits;
 struct GISelDemandedMask;
 class MachineDominatorTree;
 class LegalizerInfo;
@@ -1178,29 +1179,18 @@ class CombinerHelper {
   LLVM_ABI bool matchAVG(MachineInstr &MI, MachineRegisterInfo &MRI, Register X,
                          Register Y, unsigned TargetOpc) const;
 
-  /// If the value defined by operand \p OperandNo of \p MI can be replaced by
-  /// one of its existing operand registers when only \p Mask is demanded,
-  /// return that register. Pure (no mutation), generic opcodes only (no target
-  /// dispatch). Currently handles redundant G_AND / disjoint G_OR by a
-  /// constant.
-  std::optional<Register>
-  getDemandedBitsSimplifiedReg(MachineInstr &MI, unsigned OperandNo,
-                               const GISelDemandedMask &Mask) const;
-
-  /// Demanded-bits back-propagation driver. Attempts to rewrite the value
-  /// defined by operand \p OperandNo of \p MI under the assumption that only
-  /// the bits set in \p Mask.Bits and the elements set in \p Mask.Elts of that
-  /// def are consumed. \p OperandNo selects which def is described, so the
-  /// driver can be used on instructions with multiple defs. Returns true if
-  /// \p MI (or one of its operand-producing instructions) was mutated or
-  /// replaced.
-  bool trySimplifyDemandedBits(MachineInstr &MI, unsigned OperandNo,
-                               const GISelDemandedMask &Mask) const;
-
-  /// Match per-instruction demanded-bits simplification. Computes the
-  /// instruction's demand from its users; if its def can be replaced under that
-  /// demand, captures the apply callback. Deadloop-safe: returns true only when
-  /// the captured callback will make progress.
+  /// Simplify operand \p OpNo of \p MI when only \p Mask is demanded, and
+  /// return the known bits found during the same recursive walk in \p Known.
+  /// If the operand's register has one non-debug use, that use is replaced
+  /// and its defining instruction is erased; otherwise only this operand is
+  /// rewritten. Returns true if the IR was changed.
+  bool simplifyDemandedBits(MachineInstr &MI, unsigned OpNo,
+                            const GISelDemandedMask &Mask, KnownBits &Known,
+                            unsigned Depth = 0) const;
+
+  /// Match demanded-bits simplification of one source operand of \p MI,
+  /// treating every bit of \p MI's def as demanded. This does not compute a
+  /// union of demands from the users of \p MI's def.
   bool matchSimplifyDemandedBits(MachineInstr &MI, BuildFnTy &MatchInfo) const;
 
   // Match (G_TRUNC (G_LSHR/G_ASHR X, K-const)) when X's bits beyond DstBW
@@ -1209,6 +1199,11 @@ class CombinerHelper {
   bool matchNarrowTruncShrConst(MachineInstr &MI, BuildFnTy &MatchInfo) const;
 
 private:
+  std::optional<Register>
+  simplifyDemandedBitsForUse(MachineInstr &MI, unsigned OpNo,
+                             const GISelDemandedMask &Mask, KnownBits &Known,
+                             unsigned Depth) const;
+
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
 
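The rewrite contract documented for `simplifyDemandedBits` has two cases. With a single non-debug use, that use is rewritten and the defining instruction is erased. With several uses, only this operand is rewritten. The contract can be illustrated on a toy SSA model. This is a sketch under stated assumptions, not the CombinerHelper implementation: `ToyIR`, `countUses` and `rewriteOperand` are hypothetical, and erasing a def here does not cascade to the def's own operands.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy SSA IR: instruction I defines value I, Operands[I] lists
// the value ids it reads, and Erased[I] marks instructions that were deleted.
struct ToyIR {
  std::vector<std::vector<int>> Operands;
  std::vector<bool> Erased;
};

// Number of operand slots in live instructions that read value V.
int countUses(const ToyIR &IR, int V) {
  int N = 0;
  for (std::size_t I = 0; I < IR.Operands.size(); ++I) {
    if (IR.Erased[I])
      continue;
    for (int Op : IR.Operands[I])
      if (Op == V)
        ++N;
  }
  return N;
}

// Replace operand OpNo of User by Repl. A single-use def is erased after its
// only use is rewritten; a multi-use def stays alive for its other users.
bool rewriteOperand(ToyIR &IR, int User, unsigned OpNo, int Repl) {
  int Old = IR.Operands[User][OpNo];
  if (Old == Repl)
    return false; // no progress, mirrors the OldReg == *Repl early exit
  bool SingleUse = countUses(IR, Old) == 1;
  IR.Operands[User][OpNo] = Repl;
  if (SingleUse)
    IR.Erased[Old] = true;
  return true;
}
```

The value layout in the assertions follows `SimplifyDemandedBitsAndMultiUse`: value 2 is the producer `G_AND`, value 4 the demanding use, and value 6 a `G_LSHR` side user that must keep reading the original producer.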
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h b/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
index a2d1d8216355e..dd240df0eed00 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GISelDemandedMask.h
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 /// \file
 /// Defines the GISelDemandedMask query struct used by GISelValueTracking and
-/// CombinerHelper to express demanded-bit and demanded-element masks as a
-/// single, ergonomic value with named factories.
+/// CombinerHelper to carry GlobalISel's equivalent of SelectionDAG's
+/// (DemandedBits, DemandedElts) pair.
 //
 //===----------------------------------------------------------------------===//
 
@@ -25,10 +25,11 @@ namespace llvm {
 
 class MachineRegisterInfo;
 
-/// Bundles a demanded-bits mask and a demanded-elements mask used to
-/// constrain GISel value-tracking queries. The struct has no public
-/// positional constructor; callers must use named factories so that
-/// `(Bits, Elts)` cannot be swapped at a call site.
+/// Bundles a demanded-bits mask and a demanded-elements mask. This is a carrier
+/// for GlobalISel's equivalent of SelectionDAG's
+/// `(DemandedBits, DemandedElts)` pair, not a separate policy layer. The struct
+/// has no public positional constructor; callers must use named factories so
+/// that `(Bits, Elts)` cannot be swapped at a call site.
 struct GISelDemandedMask {
   APInt Bits;
   APInt Elts;
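The named-factory idiom described in the `GISelDemandedMask` comment can be sketched without LLVM. The class below is hypothetical. It uses `std::uint64_t` masks (widths up to 64) in place of `APInt` and takes widths directly rather than a `MachineRegisterInfo` and register. Its factory signatures therefore differ from the real `getAllBits`/`forBits`/`forElts`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature of the named-factory pattern; not the real API.
class DemandedMaskSketch {
public:
  std::uint64_t Bits;
  std::uint64_t Elts;
  unsigned BitWidth;
  unsigned NumElts; // 1 for scalars

  static DemandedMaskSketch getAllBits(unsigned BW, unsigned NE = 1) {
    return DemandedMaskSketch(lowBits(BW), lowBits(NE), BW, NE);
  }
  static DemandedMaskSketch forBits(unsigned BW, std::uint64_t B,
                                    unsigned NE = 1) {
    return DemandedMaskSketch(B, lowBits(NE), BW, NE);
  }
  static DemandedMaskSketch forElts(unsigned BW, std::uint64_t E,
                                    unsigned NE) {
    return DemandedMaskSketch(lowBits(BW), E, BW, NE);
  }
  // Builders keep both widths and replace exactly one component.
  DemandedMaskSketch withBits(std::uint64_t B) const {
    return DemandedMaskSketch(B, Elts, BitWidth, NumElts);
  }
  DemandedMaskSketch withElts(std::uint64_t E) const {
    return DemandedMaskSketch(Bits, E, BitWidth, NumElts);
  }

private:
  // Positional constructor is private, so call sites cannot swap the masks.
  DemandedMaskSketch(std::uint64_t B, std::uint64_t E, unsigned BW,
                     unsigned NE)
      : Bits(B & lowBits(BW)), Elts(E & lowBits(NE)), BitWidth(BW),
        NumElts(NE) {
    assert(BW >= 1 && BW <= 64 && NE >= 1 && NE <= 64 && "bad width");
  }
  static std::uint64_t lowBits(unsigned N) {
    return N >= 64 ? ~std::uint64_t(0) : ((std::uint64_t(1) << N) - 1);
  }
};
```

Because `forBits` and `forElts` are distinct names, passing an elements mask where a bits mask is expected is visible at the call site rather than hidden in argument order.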
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h b/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
index 85819bd32101c..c7dd1d3133ba0 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
@@ -110,20 +110,6 @@ class LLVM_ABI GISelValueTracking : public GISelChangeObserver {
     Known.Zero.setLowBits(Log2(Alignment));
   }
 
-  /// Bits of operand \p OpIdx of \p UseMI that \p UseMI demands, given that
-  /// only \p DemandOfUser bits of UseMI's def are demanded. The returned mask
-  /// is at the operand's scalar bit width. Over-approximates: returns all-ones
-  /// for opcodes/operands whose reverse transfer is not modeled.
-  APInt getDemandedBitsForUse(const MachineInstr &UseMI, unsigned OpIdx,
-                              const APInt &DemandOfUser);
-
-  /// Union, over all non-debug users of \p R, of the bits each user demands of
-  /// \p R (via getDemandedBitsForUse, recursing upward for each user's own
-  /// demand). Returns all-ones at/past the recursion cap, for vector/invalid
-  /// types, or when a user has multiple defs. Returns an empty (zero) mask when
-  /// \p R has no users (a dead value demands nothing).
-  APInt computeDemandedBits(Register R, unsigned Depth = 0);
-
   /// \return The known alignment for the pointer-like value \p R.
   Align computeKnownAlignment(Register R, unsigned Depth = 0);
 
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 46595ee7295a3..e6300c3552ac6 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1122,9 +1122,8 @@ def narrow_trunc_shr_const : GICombineRule<
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
 // Demand-driven elimination of redundant G_AND / disjoint G_OR by a constant.
-// Computes the union of bit-demands from all users of the def; if the constant
-// mask (for AND) covers every demanded bit, or the constant or-bits (for OR)
-// are entirely undemanded, the instruction is replaced by its LHS operand.
+// Simplifies the root's operands assuming every bit of the root def is
+// demanded; it does not compute a union of demands from the root's users.
 // Deadloop-safe: match returns true only when the rewrite will make progress.
 // Wired into the AArch64 post-legalizer combiner; not in the generic groups so
 // other targets opt in explicitly.
@@ -1132,7 +1131,7 @@ def simplify_demanded_bits : GICombineRule<
   (defs root:$root, build_fn_matchinfo:$matchinfo),
   (match (wip_match_opcode G_AND, G_OR):$root,
          [{ return Helper.matchSimplifyDemandedBits(*${root}, ${matchinfo}); }]),
-  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+  (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
 
 // Transform (mul x, -1) -> (sub 0, x)
 def mul_by_neg_one: GICombineRule <
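`matchSimplifyDemandedBits` and its recursive helper query the RHS first. They then shrink the LHS demand using the RHS's known bits, and finally merge both sides' known bits. A 32-bit sketch of that bookkeeping follows. `Known32`, `narrowLHSDemand`, `combineKnown` and `knownConst` are hypothetical stand-ins for `KnownBits` and its `&=`/`|=` operators.

```cpp
#include <cassert>
#include <cstdint>

// Minimal 32-bit stand-in for llvm::KnownBits.
struct Known32 {
  std::uint32_t Zero = 0;
  std::uint32_t One = 0;
};

enum class BinOp { And, Or };

// Known bits of a constant: every bit is known.
Known32 knownConst(std::uint32_t C) { return {~C, C}; }

// Demand left on the LHS once the RHS known bits are available.
// AND: where RHS is known zero, the result is zero regardless of LHS.
// OR:  where RHS is known one, the result is one regardless of LHS.
std::uint32_t narrowLHSDemand(BinOp Op, std::uint32_t Demand, Known32 RHS) {
  return Op == BinOp::And ? (Demand & ~RHS.Zero) : (Demand & ~RHS.One);
}

// Merge the two sides, as KnownBits::operator&= / operator|= do.
Known32 combineKnown(BinOp Op, Known32 L, Known32 R) {
  if (Op == BinOp::And)
    return {L.Zero | R.Zero, L.One & R.One};
  return {L.Zero & R.Zero, L.One | R.One};
}
```

With an all-ones root demand and a constant RHS of 255, an AND leaves only the low 8 LHS bits demanded, while an OR leaves only the high 24.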
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index f5346fe3d06df..bda88ff378faf 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -11,6 +11,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
 #include "llvm/CodeGen/GlobalISel/GISelDemandedMask.h"
 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
@@ -2855,18 +2856,18 @@ void CombinerHelper::applyCombineTruncOfShift(
   Register ShiftSrc = ShiftMI->getOperand(1).getReg();
 
   // Stage-3 demanded-bits hook: the upcoming inner G_TRUNC consumes only the
-  // low NewShiftTy bits of ShiftSrc. If its producer is a redundant
-  // mask/disjoint-or, drive simplification before truncating so we don't
+  // low NewShiftTy bits of ShiftSrc. If this use can see through a redundant
+  // mask/disjoint-or, simplify the shift operand before truncating so we don't
   // synthesise dead high-bit work that DCE has to clean up later.
-  if (auto *ShiftSrcMI = getDefIgnoringCopies(ShiftSrc, MRI)) {
-    APInt LowMask =
-        APInt::getLowBitsSet(MRI.getType(ShiftSrc).getScalarSizeInBits(),
-                             NewShiftTy.getScalarSizeInBits());
-    trySimplifyDemandedBits(*ShiftSrcMI, /*OperandNo=*/0,
-                            GISelDemandedMask::forBits(MRI, ShiftSrc, LowMask));
-    // Re-read the operand: driver may have RAUW-rerouted it.
-    ShiftSrc = ShiftMI->getOperand(1).getReg();
-  }
+  APInt LowMask =
+      APInt::getLowBitsSet(MRI.getType(ShiftSrc).getScalarSizeInBits(),
+                           NewShiftTy.getScalarSizeInBits());
+  KnownBits Known(MRI.getType(ShiftSrc).getScalarSizeInBits());
+  simplifyDemandedBits(*ShiftMI, /*OpNo=*/1,
+                       GISelDemandedMask::forBits(MRI, ShiftSrc, LowMask),
+                       Known);
+  // Re-read the operand: the demanded-bits simplifier may have rerouted it.
+  ShiftSrc = ShiftMI->getOperand(1).getReg();
 
   ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);
 
@@ -8833,81 +8834,173 @@ bool CombinerHelper::matchAVG(MachineInstr &MI, MachineRegisterInfo &MRI,
   return XTy == MRI.getType(Y) && isLegal({TargetOpc, {XTy}});
 }
 
-std::optional<Register> CombinerHelper::getDemandedBitsSimplifiedReg(
-    MachineInstr &MI, unsigned OperandNo, const GISelDemandedMask &Mask) const {
-  if (OperandNo >= MI.getNumOperands() || !MI.getOperand(OperandNo).isReg() ||
-      !MI.getOperand(OperandNo).isDef())
+static bool isRegUseOperand(const MachineInstr &MI, unsigned OpNo) {
+  return OpNo < MI.getNumOperands() && MI.getOperand(OpNo).isReg() &&
+         !MI.getOperand(OpNo).isDef();
+}
+
+std::optional<Register> CombinerHelper::simplifyDemandedBitsForUse(
+    MachineInstr &MI, unsigned OpNo, const GISelDemandedMask &Mask,
+    KnownBits &Known, unsigned Depth) const {
+  Known = KnownBits(Mask.Bits.getBitWidth());
+  if (!isRegUseOperand(MI, OpNo) || Mask.Bits.isZero())
     return std::nullopt;
-  Register Dst = MI.getOperand(OperandNo).getReg();
-  LLT DstTy = MRI.getType(Dst);
-  if (!DstTy.isValid())
+
+  Register OpReg = MI.getOperand(OpNo).getReg();
+  LLT OpTy = MRI.getType(OpReg);
+  if (!VT)
+    return std::nullopt;
+
+  if (!OpTy.isValid() ||
+      Mask.Bits.getBitWidth() != OpTy.getScalarSizeInBits()) {
+    Known = VT->getKnownBits(OpReg);
     return std::nullopt;
-  unsigned BW = DstTy.getScalarSizeInBits();
-  if (Mask.Bits.getBitWidth() != BW)
+  }
+
+  auto GiveUp = [&]() -> std::optional<Register> {
+    Known = VT->getKnownBits(OpReg, Mask, Depth);
     return std::nullopt;
+  };
 
-  switch (MI.getOpcode()) {
-  case TargetOpcode::G_AND: {
-    // (and X, C) with (C & Mask) == Mask -> X is sufficient.
-    Register LHS = MI.getOperand(1).getReg();
-    std::optional<APInt> C =
-        getConstantOrConstantSplatVector(MI.getOperand(2).getReg());
-    if (!C || C->getBitWidth() != BW)
-      return std::nullopt;
-    if (!Mask.Bits.isSubsetOf(*C))
-      return std::nullopt;
-    if (MRI.getType(LHS) != DstTy)
+  if (Depth >= MaxAnalysisRecursionDepth)
+    return GiveUp();
+
+  MachineInstr *DefMI = OpReg.isVirtual() ? MRI.getVRegDef(OpReg) : nullptr;
+  if (!DefMI)
+    return GiveUp();
+
+  unsigned Opcode = DefMI->getOpcode();
+  if (Opcode != TargetOpcode::G_AND && Opcode != TargetOpcode::G_OR)
+    return GiveUp();
+
+  if (DefMI->getNumExplicitDefs() != 1 || !DefMI->getOperand(0).isReg() ||
+      !isRegUseOperand(*DefMI, 1) || !isRegUseOperand(*DefMI, 2))
+    return GiveUp();
+
+  Register Dst = DefMI->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  if (!DstTy.isValid() ||
+      Mask.Bits.getBitWidth() != DstTy.getScalarSizeInBits())
+    return GiveUp();
+
+  unsigned BW = Mask.Bits.getBitWidth();
+  Register LHS = DefMI->getOperand(1).getReg();
+  Register RHS = DefMI->getOperand(2).getReg();
+  auto SimplifyWithConst = [&](Register X,
+                               Register CReg) -> std::optional<Register> {
+    if (MRI.getType(X) != DstTy || MRI.getType(CReg) != DstTy)
       return std::nullopt;
-    return LHS;
-  }
-  case TargetOpcode::G_OR: {
-    // (or X, C) with (C & Mask) == 0 -> X is sufficient.
-    Register LHS = MI.getOperand(1).getReg();
-    std::optional<APInt> C =
-        getConstantOrConstantSplatVector(MI.getOperand(2).getReg());
+
+    std::optional<APInt> C = getConstantOrConstantSplatVector(CReg);
     if (!C || C->getBitWidth() != BW)
       return std::nullopt;
-    if ((*C & Mask.Bits).getBoolValue())
-      return std::nullopt;
-    if (MRI.getType(LHS) != DstTy)
+
+    if (Opcode == TargetOpcode::G_AND) {
+      if (Mask.Bits.isSubsetOf(*C))
+        return X;
+      if (Mask.Bits.isSubsetOf(~*C))
+        return CReg;
       return std::nullopt;
-    return LHS;
-  }
-  default:
+    }
+
+    if (Mask.Bits.isSubsetOf(~*C))
+      return X;
+    if (Mask.Bits.isSubsetOf(*C))
+      return CReg;
     return std::nullopt;
-  }
+  };
+
+  KnownBits RHSKnown(BW);
+  simplifyDemandedBitsForUse(*DefMI, /*OpNo=*/2, Mask, RHSKnown, Depth + 1);
+
+  APInt LHSDemand = Mask.Bits;
+  if (Opcode == TargetOpcode::G_AND)
+    LHSDemand &= ~RHSKnown.Zero;
+  else
+    LHSDemand &= ~RHSKnown.One;
+
+  KnownBits LHSKnown(BW);
+  simplifyDemandedBitsForUse(*DefMI, /*OpNo=*/1, Mask.withBits(LHSDemand),
+                             LHSKnown, Depth + 1);
+
+  Known = LHSKnown;
+  if (Opcode == TargetOpcode::G_AND)
+    Known &= RHSKnown;
+  else
+    Known |= RHSKnown;
+
+  if (std::optional<Register> Repl = SimplifyWithConst(LHS, RHS))
+    return Repl;
+  return SimplifyWithConst(RHS, LHS);
 }
 
-bool CombinerHelper::trySimplifyDemandedBits(
-    MachineInstr &MI, unsigned OperandNo, const GISelDemandedMask &Mask) const {
-  if (std::optional<Register> R =
-          getDemandedBitsSimplifiedReg(MI, OperandNo, Mask)) {
-    replaceRegWith(MRI, MI.getOperand(OperandNo).getReg(), *R);
-    eraseInst(MI);
-    return true;
+bool CombinerHelper::simplifyDemandedBits(MachineInstr &MI, unsigned OpNo,
+                                          const GISelDemandedMask &Mask,
+                                          KnownBits &Known,
+                                          unsigned Depth) const {
+  std::optional<Register> Repl =
+      simplifyDemandedBitsForUse(MI, OpNo, Mask, Known, Depth);
+  if (!Repl)
+    return false;
+
+  MachineOperand &UseMO = MI.getOperand(OpNo);
+  Register OldReg = UseMO.getReg();
+  if (OldReg == *Repl)
+    return false;
+
+  if (OldReg.isVirtual() && MRI.hasOneNonDBGUse(OldReg)) {
+    if (MachineInstr *DefMI = MRI.getVRegDef(OldReg)) {
+      replaceRegWith(MRI, OldReg, *Repl);
+      eraseInst(*DefMI);
+      return true;
+    }
   }
-  return false;
+
+  replaceRegOpWith(MRI, UseMO, *Repl);
+  return true;
 }
 
 bool CombinerHelper::matchSimplifyDemandedBits(MachineInstr &MI,
                                                BuildFnTy &MatchInfo) const {
-  if (MI.getNumExplicitDefs() != 1 || !MI.getOperand(0).isReg())
+  if (MI.getNumExplicitDefs() != 1 || !MI.getOperand(0).isReg() ||
+      !isRegUseOperand(MI, 1) || !isRegUseOperand(MI, 2))
     return false;
+
   Register Dst = MI.getOperand(0).getReg();
   LLT Ty = MRI.getType(Dst);
   if (!Ty.isValid() || Ty.isVector())
     return false;
-  unsigned BW = Ty.getScalarSizeInBits();
-  APInt Demand = VT->computeDemandedBits(Dst);
-  if (Demand.getBitWidth() != BW || Demand.isAllOnes())
-    return false; // nothing un-demanded -> no opportunity, and avoids re-firing
-  std::optional<Register> R = getDemandedBitsSimplifiedReg(
-      MI, /*OperandNo=*/0, GISelDemandedMask::forBits(MRI, Dst, Demand));
-  if (!R)
+
+  GISelDemandedMask RootMask = GISelDemandedMask::getAllBits(MRI, Dst);
+  auto MatchOperand = [&](unsigned OpNo, const GISelDemandedMask &OpMask,
+                          KnownBits &Known) {
+    std::optional<Register> Repl =
+        simplifyDemandedBitsForUse(MI, OpNo, OpMask, Known, /*Depth=*/0);
+    if (!Repl)
+      return false;
+
+    MatchInfo = [this, &MI, OpNo, OpMask](MachineIRBuilder &B) {
+      KnownBits Known(OpMask.Bits.getBitWidth());
+      simplifyDemandedBits(MI, OpNo, OpMask, Known);
+    };
+    return true;
+  };
+
+  unsigned Opcode = MI.getOpcode();
+  KnownBits RHSKnown(RootMask.Bits.getBitWidth());
+  if (MatchOperand(/*OpNo=*/2, RootMask, RHSKnown))
+    return true;
+
+  APInt LHSDemand = RootMask.Bits;
+  if (Opcode == TargetOpcode::G_AND)
+    LHSDemand &= ~RHSKnown.Zero;
+  else if (Opcode == TargetOpcode::G_OR)
+    LHSDemand &= ~RHSKnown.One;
+  else
     return false;
-  Register Repl = *R;
-  MatchInfo = [=](MachineIRBuilder &B) { replaceRegWith(MRI, Dst, Repl); };
-  return true;
+
+  KnownBits LHSKnown(RootMask.Bits.getBitWidth());
+  return MatchOperand(/*OpNo=*/1, RootMask.withBits(LHSDemand), LHSKnown);
 }
 
 // (trunc (lshr X, K)) with bits [DstBW, DstBW+K) of X known-zero
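The four replacement cases in the `SimplifyWithConst` lambda reduce to two subset tests on the demanded mask. The following scalar sketch reproduces that decision table. `simplifyWithConst`, `Pick` and `BinOp` are hypothetical names. Only scalar 32-bit constants are modeled; the real code also accepts constant splats via `getConstantOrConstantSplatVector`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

enum class BinOp { And, Or };
enum class Pick { X, Const }; // which operand of (X op C) replaces the value

std::optional<Pick> simplifyWithConst(BinOp Op, std::uint32_t Demand,
                                      std::uint32_t C) {
  if (Demand == 0)
    return std::nullopt; // the real helper bails out early on an empty mask
  bool DemandInC = (Demand & ~C) == 0; // Demand is a subset of C
  bool DemandOutC = (Demand & C) == 0; // Demand is a subset of ~C
  if (Op == BinOp::And) {
    if (DemandInC)
      return Pick::X;     // C keeps every demanded bit of X
    if (DemandOutC)
      return Pick::Const; // demanded result bits are all 0, as in C
    return std::nullopt;
  }
  if (DemandOutC)
    return Pick::X;       // C sets no demanded bit
  if (DemandInC)
    return Pick::Const;   // demanded result bits are all 1, as in C
  return std::nullopt;
}
```

The assertions mirror the MIR unit tests above: demand `0x0F` through `G_AND` with 255 picks `X`, demand `0xFF` through `G_OR` with 65280 picks `X`, and demand `0x0F` through `G_OR` with 255 picks the constant.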
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
index 14753f00027dd..71d018861d71c 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp
@@ -230,97 +230,6 @@ static APInt getShiftSrcDemandedBits(GISelValueTracking &VT, Register AmtReg,
   return SrcBits;
 }
 
-APInt GISelValueTracking::getDemandedBitsForUse(const MachineInstr &UseMI,
-                                                unsigned OpIdx,
-                                                const APInt &DemandOfUser) {
-  Register OpReg = UseMI.getOperand(OpIdx).getReg();
-  LLT OpTy = MRI.getType(OpReg);
-  if (!OpTy.isValid() || OpTy.isVector())
-    return APInt::getAllOnes(OpTy.isValid() ? OpTy.getScalarSizeInBits() : 1);
-  unsigned OpBW = OpTy.getScalarSizeInBits();
-
-  switch (UseMI.getOpcode()) {
-  case TargetOpcode::G_TRUNC:
-    // def (narrow) bit p == operand bit p; high operand bits not demanded.
-    // The def is narrower than the operand, so DemandOfUser should fit; guard
-    // symmetrically with the other cases against a wider-than-expected demand.
-    return DemandOfUser.getBitWidth() <= OpBW ? DemandOfUser.zext(OpBW)
-                                              : APInt::getAllOnes(OpBW);
-  case TargetOpcode::G_AND: {
-    // result = X & C. If the OTHER operand is a constant C, X is demanded only
-    // where the result is demanded AND C is set. Otherwise demand all of D.
-    unsigned Other = OpIdx == 1 ? 2 : 1;
-    if (auto C = getIConstantVRegValWithLookThrough(
-            UseMI.getOperand(Other).getReg(), MRI);
-        C && C->Value.getBitWidth() == OpBW &&
-        DemandOfUser.getBitWidth() == OpBW)
-      return DemandOfUser & C->Value;
-    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
-                                              : APInt::getAllOnes(OpBW);
-  }
-  case TargetOpcode::G_OR: {
-    // result = X | C. Where C is set, result is 1 regardless of X, so X is
-    // not demanded there. Only a constant other-operand is modeled.
-    unsigned Other = OpIdx == 1 ? 2 : 1;
-    if (auto C = getIConstantVRegValWithLookThrough(
-            UseMI.getOperand(Other).getReg(), MRI);
-        C && C->Value.getBitWidth() == OpBW &&
-        DemandOfUser.getBitWidth() == OpBW)
-      return DemandOfUser & ~C->Value;
-    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
-                                              : APInt::getAllOnes(OpBW);
-  }
-  case TargetOpcode::G_XOR:
-    // Bit-local: each demanded result bit needs the matching operand bit.
-    return DemandOfUser.getBitWidth() == OpBW ? DemandOfUser
-                                              : APInt::getAllOnes(OpBW);
-  case TargetOpcode::G_ZEXT:
-    // def low OpBW bits == operand; higher def bits are zero (independent of
-    // X).
-    return DemandOfUser.getBitWidth() >= OpBW ? DemandOfUser.trunc(OpBW)
-                                              : APInt::getAllOnes(OpBW);
-  case TargetOpcode::G_SEXT: {
-    // def low OpBW bits == operand; every def bit at index >= OpBW-1 comes from
-    // the operand sign bit (OpBW-1).
-    if (OpBW == 0 || DemandOfUser.getBitWidth() < OpBW)
-      return APInt::getAllOnes(OpBW);
-    APInt Low = DemandOfUser.trunc(OpBW);
-    if (!DemandOfUser.lshr(OpBW - 1).isZero())
-      Low.setBit(OpBW - 1);
-    return Low;
-  }
-  default:
-    return APInt::getAllOnes(OpBW);
-  }
-}
-
-APInt GISelValueTracking::computeDemandedBits(Register R, unsigned Depth) {
-  LLT Ty = MRI.getType(R);
-  unsigned BW = Ty.isValid() ? Ty.getScalarSizeInBits() : 1;
-  if (!Ty.isValid() || Ty.isVector())
-    return APInt::getAllOnes(BW);
-  if (Depth >= MaxAnalysisRecursionDepth)
-    return APInt::getAllOnes(BW);
-
-  APInt Demand = APInt::getZero(BW);
-  for (const MachineInstr &UseMI : MRI.use_nodbg_instructions(R)) {
-    if (UseMI.getNumExplicitDefs() != 1) {
-      Demand = APInt::getAllOnes(BW);
-      break;
-    }
-    APInt UserDemand =
-        computeDemandedBits(UseMI.getOperand(0).getReg(), Depth + 1);
-    for (unsigned I = 0, E = UseMI.getNumOperands(); I != E; ++I) {
-      const MachineOperand &MO = UseMI.getOperand(I);
-      if (MO.isReg() && !MO.isDef() && MO.getReg() == R)
-        Demand |= getDemandedBitsForUse(UseMI, I, UserDemand);
-    }
-    if (Demand.isAllOnes())
-      break;
-  }
-  return Demand;
-}
-
 // Bitfield extract is computed as (Src >> Offset) & Mask, where Mask is
 // created using Width. Use this function when the inputs are KnownBits
 // objects. TODO: Move this KnownBits.h if this is usable in more cases.
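Among the reverse-transfer rules deleted from `getDemandedBitsForUse`, the `G_SEXT` case is the least obvious, because every result bit at or above the operand's sign position copies that sign bit. Below is a standalone sketch of just that rule, using hypothetical names and `std::uint64_t` masks in place of `APInt`. Unlike the deleted code, it takes the result width as an explicit parameter.

```cpp
#include <cassert>
#include <cstdint>

std::uint64_t lowBits(unsigned N) {
  return N >= 64 ? ~std::uint64_t(0) : ((std::uint64_t(1) << N) - 1);
}

// Bits of the OpBW-wide operand demanded by a G_SEXT to DstBW bits, given
// the bits demanded of the sign-extended result.
std::uint64_t sextOperandDemand(std::uint64_t DemandOfUser, unsigned OpBW,
                                unsigned DstBW) {
  assert(OpBW >= 1 && OpBW <= DstBW && DstBW <= 64);
  // Low OpBW result bits map one-to-one onto operand bits.
  std::uint64_t Low = DemandOfUser & lowBits(OpBW);
  // Any demanded result bit at index >= OpBW - 1 needs the sign bit.
  if ((DemandOfUser & lowBits(DstBW)) >> (OpBW - 1))
    Low |= std::uint64_t(1) << (OpBW - 1);
  return Low;
}
```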
diff --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
index b93058246e8e6..958007c2509a6 100644
--- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "GISelMITest.h"
+#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 
@@ -2174,6 +2175,34 @@ KnownBits runShiftAnalysis(MachineRegisterInfo &MRI, GISelValueTracking &Info,
   return Info.getKnownBits(R, GISelDemandedMask::forBits(MRI, R, Bits));
 }
 
+MachineInstr *findOpcode(MachineFunction &MF, unsigned Opcode,
+                         unsigned Index = 0) {
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() == Opcode) {
+        if (Index == 0)
+          return &MI;
+        --Index;
+      }
+    }
+  }
+  return nullptr;
+}
+
+KnownBits simplifyDemandedBitsOperand(MachineFunction &MF,
+                                      MachineRegisterInfo &MRI,
+                                      MachineIRBuilder &B, MachineInstr &Use,
+                                      const APInt &Demand) {
+  GISelValueTracking VT(MF);
+  CombinerHelper Helper(VT, B, /*IsPreLegalize=*/false, &VT);
+  KnownBits Known(Demand.getBitWidth());
+  EXPECT_TRUE(Helper.simplifyDemandedBits(
+      Use, /*OpNo=*/1,
+      GISelDemandedMask::forBits(MRI, Use.getOperand(1).getReg(), Demand),
+      Known));
+  return Known;
+}
+
 } // namespace
 
 TEST_F(AArch64GISelMITest, KnownBitsSHLDemandEquivalence) {
@@ -2411,60 +2440,143 @@ TEST_F(AArch64GISelMITest, KnownBitsLSHRDemandSoundness_i4) {
   }
 }
 
-TEST_F(AArch64GISelMITest, DemandedBitsForUseTruncAnd) {
+TEST_F(AArch64GISelMITest, SimplifyDemandedBitsAndSingleUse) {
   StringRef MIRString = R"(
-    %0:_(s64) = COPY $x0
-    %trunc:_(s32) = G_TRUNC %0
-    %c:_(s32) = G_CONSTANT i32 255
-    %and:_(s32) = G_AND %trunc, %c
-    %4:_(s32) = COPY %and
+    %x:_(s32) = G_TRUNC %0
+    %mask:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %x, %mask
+    %lowmask:_(s32) = G_CONSTANT i32 15
+    %use:_(s32) = G_AND %and, %lowmask
+    %out:_(s32) = COPY %use
 )";
   setUp(MIRString);
   if (!TM)
     GTEST_SKIP();
-  GISelValueTracking VT(*MF);
 
-  // Final COPY's source is %and.
-  MachineInstr *FinalCopy = MRI->getVRegDef(Copies.back());
-  Register AndReg = FinalCopy->getOperand(1).getReg();
-  MachineInstr *AndMI = MRI->getVRegDef(AndReg);
-  Register TruncReg = AndMI->getOperand(1).getReg();
-  MachineInstr *TruncMI = MRI->getVRegDef(TruncReg);
-
-  // AND with C=0xFF, def fully demanded -> variable operand demanded in low 8.
-  EXPECT_EQ(
-      VT.getDemandedBitsForUse(*AndMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
-      APInt(32, 0xFF));
-  // TRUNC s64<-...: operand (s64) demanded in low 32 bits.
-  EXPECT_EQ(
-      VT.getDemandedBitsForUse(*TruncMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
-      APInt(64, 0xFFFFFFFFULL));
-}
-
-TEST_F(AArch64GISelMITest, ComputeDemandedBitsAndMask) {
-  // %0 is s64 (pre-populated by the framework as COPY $x0). Trunc to s32,
-  // then AND with 255. computeDemandedBits on the s32 trunc register should
-  // report only the low 8 bits demanded (the COPY %out demands all 32, the
-  // AND narrows that to 0xFF via the constant mask).
+  MachineInstr *Producer = findOpcode(*MF, TargetOpcode::G_AND);
+  MachineInstr *Use = findOpcode(*MF, TargetOpcode::G_AND, /*Index=*/1);
+  ASSERT_NE(Producer, nullptr);
+  ASSERT_NE(Use, nullptr);
+
+  Register ProducerReg = Producer->getOperand(0).getReg();
+  Register XReg = Producer->getOperand(1).getReg();
+  simplifyDemandedBitsOperand(*MF, *MRI, B, *Use, APInt(32, 0x0F));
+  EXPECT_EQ(Use->getOperand(1).getReg(), XReg);
+  EXPECT_TRUE(MRI->use_nodbg_empty(ProducerReg));
+}
+
+TEST_F(AArch64GISelMITest, SimplifyDemandedBitsAndMultiUse) {
   StringRef MIRString = R"(
-    %trunc:_(s32) = G_TRUNC %0
-    %c:_(s32) = G_CONSTANT i32 255
-    %and:_(s32) = G_AND %trunc, %c
-    %out:_(s32) = COPY %and
+    %x:_(s32) = G_TRUNC %0
+    %mask:_(s32) = G_CONSTANT i32 255
+    %and:_(s32) = G_AND %x, %mask
+    %lowmask:_(s32) = G_CONSTANT i32 15
+    %use:_(s32) = G_AND %and, %lowmask
+    %amt:_(s32) = G_CONSTANT i32 8
+    %side:_(s32) = G_LSHR %and, %amt
+    %out:_(s32) = COPY %use
+    %side_out:_(s32) = COPY %side
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+
+  MachineInstr *Producer = findOpcode(*MF, TargetOpcode::G_AND);
+  MachineInstr *Use = findOpcode(*MF, TargetOpcode::G_AND, /*Index=*/1);
+  MachineInstr *Side = findOpcode(*MF, TargetOpcode::G_LSHR);
+  ASSERT_NE(Producer, nullptr);
+  ASSERT_NE(Use, nullptr);
+  ASSERT_NE(Side, nullptr);
+
+  Register ProducerReg = Producer->getOperand(0).getReg();
+  Register XReg = Producer->getOperand(1).getReg();
+  simplifyDemandedBitsOperand(*MF, *MRI, B, *Use, APInt(32, 0x0F));
+  EXPECT_EQ(Use->getOperand(1).getReg(), XReg);
+  EXPECT_EQ(Side->getOperand(1).getReg(), ProducerReg);
+  EXPECT_EQ(MRI->getVRegDef(ProducerReg)->getOpcode(), TargetOpcode::G_AND);
+}
+
+TEST_F(AArch64GISelMITest, SimplifyDemandedBitsOrSingleUse) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_TRUNC %0
+    %high:_(s32) = G_CONSTANT i32 65280
+    %or:_(s32) = G_OR %x, %high
+    %lowmask:_(s32) = G_CONSTANT i32 255
+    %use:_(s32) = G_AND %or, %lowmask
+    %out:_(s32) = COPY %use
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+
+  MachineInstr *Producer = findOpcode(*MF, TargetOpcode::G_OR);
+  MachineInstr *Use = findOpcode(*MF, TargetOpcode::G_AND);
+  ASSERT_NE(Producer, nullptr);
+  ASSERT_NE(Use, nullptr);
+
+  Register ProducerReg = Producer->getOperand(0).getReg();
+  Register XReg = Producer->getOperand(1).getReg();
+  simplifyDemandedBitsOperand(*MF, *MRI, B, *Use, APInt(32, 0xFF));
+  EXPECT_EQ(Use->getOperand(1).getReg(), XReg);
+  EXPECT_TRUE(MRI->use_nodbg_empty(ProducerReg));
+}
+
+TEST_F(AArch64GISelMITest, SimplifyDemandedBitsOrMultiUse) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_TRUNC %0
+    %high:_(s32) = G_CONSTANT i32 65280
+    %or:_(s32) = G_OR %x, %high
+    %lowmask:_(s32) = G_CONSTANT i32 255
+    %use:_(s32) = G_AND %or, %lowmask
+    %amt:_(s32) = G_CONSTANT i32 8
+    %side:_(s32) = G_LSHR %or, %amt
+    %out:_(s32) = COPY %use
+    %side_out:_(s32) = COPY %side
+)";
+  setUp(MIRString);
+  if (!TM)
+    GTEST_SKIP();
+
+  MachineInstr *Producer = findOpcode(*MF, TargetOpcode::G_OR);
+  MachineInstr *Use = findOpcode(*MF, TargetOpcode::G_AND);
+  MachineInstr *Side = findOpcode(*MF, TargetOpcode::G_LSHR);
+  ASSERT_NE(Producer, nullptr);
+  ASSERT_NE(Use, nullptr);
+  ASSERT_NE(Side, nullptr);
+
+  Register ProducerReg = Producer->getOperand(0).getReg();
+  Register XReg = Producer->getOperand(1).getReg();
+  simplifyDemandedBitsOperand(*MF, *MRI, B, *Use, APInt(32, 0xFF));
+  EXPECT_EQ(Use->getOperand(1).getReg(), XReg);
+  EXPECT_EQ(Side->getOperand(1).getReg(), ProducerReg);
+  EXPECT_EQ(MRI->getVRegDef(ProducerReg)->getOpcode(), TargetOpcode::G_OR);
+}
+
+TEST_F(AArch64GISelMITest, SimplifyDemandedBitsOrConstantExplainsDemand) {
+  StringRef MIRString = R"(
+    %x:_(s32) = G_TRUNC %0
+    %low:_(s32) = G_CONSTANT i32 255
+    %or:_(s32) = G_OR %x, %low
+    %usemask:_(s32) = G_CONSTANT i32 15
+    %use:_(s32) = G_AND %or, %usemask
+    %out:_(s32) = COPY %use
 )";
   setUp(MIRString);
   if (!TM)
     GTEST_SKIP();
-  GISelValueTracking VT(*MF);
 
-  // Walk: final COPY (%out) -> %and -> its variable operand %trunc.
-  Register AndReg = MRI->getVRegDef(Copies.back())->getOperand(1).getReg();
-  MachineInstr *AndMI = MRI->getVRegDef(AndReg);
-  Register TruncReg = AndMI->getOperand(1).getReg();
+  MachineInstr *Producer = findOpcode(*MF, TargetOpcode::G_OR);
+  MachineInstr *Use = findOpcode(*MF, TargetOpcode::G_AND);
+  ASSERT_NE(Producer, nullptr);
+  ASSERT_NE(Use, nullptr);
 
-  // %trunc is only used by (and %trunc, 255) -> demanded just in the low 8
-  // bits.
-  EXPECT_EQ(VT.computeDemandedBits(TruncReg, /*Depth=*/0), APInt(32, 0xFF));
+  Register ProducerReg = Producer->getOperand(0).getReg();
+  Register LowCstReg = Producer->getOperand(2).getReg();
+  KnownBits Known =
+      simplifyDemandedBitsOperand(*MF, *MRI, B, *Use, APInt(32, 0x0F));
+  EXPECT_EQ(Use->getOperand(1).getReg(), LowCstReg);
+  EXPECT_TRUE(MRI->use_nodbg_empty(ProducerReg));
+  EXPECT_TRUE(APInt(32, 0x0F).isSubsetOf(Known.One));
 }
 
 TEST_F(AArch64GISelMITest, KnownBitsASHRDemandSoundness_i4) {
@@ -2546,158 +2658,3 @@ TEST_F(AArch64GISelMITest, KnownBitsRangedShiftAmountDemandSoundness_i4) {
     }
   }
 }
-
-// Part B: spot-check getDemandedBitsForUse for G_OR, G_XOR, G_ZEXT, G_SEXT.
-TEST_F(AArch64GISelMITest, DemandedBitsForUseOrXorZextSext) {
-  // MIR layout (all regs named for clarity):
-  //   %x8  = G_TRUNC %0        ; s32 narrow input
-  //   %c0f = G_CONSTANT i32 15 ; 0x0F
-  //   %or  = G_OR  %x8, %c0f  ; G_OR case: constant other-operand = 0x0F
-  //   %c1  = G_CONSTANT i32 42 ; arbitrary second XOR operand
-  //   %xr  = G_XOR %x8, %c1   ; G_XOR case
-  //   %t8  = G_TRUNC %0        ; s8 input for ext tests
-  //   %zx  = G_ZEXT %t8        ; s8->s32 zext
-  //   %sx  = G_SEXT %t8        ; s8->s32 sext
-  //   %out = COPY %sx          ; anchor to Copies.back()
-  StringRef MIRString = R"(
-    %x8:_(s32) = G_TRUNC %0
-    %c0f:_(s32) = G_CONSTANT i32 15
-    %or:_(s32) = G_OR %x8, %c0f
-    %c1:_(s32) = G_CONSTANT i32 42
-    %xr:_(s32) = G_XOR %x8, %c1
-    %t8:_(s8) = G_TRUNC %0
-    %zx:_(s32) = G_ZEXT %t8
-    %sx:_(s32) = G_SEXT %t8
-    %out:_(s32) = COPY %sx
-)";
-  setUp(MIRString);
-  if (!TM)
-    GTEST_SKIP();
-  GISelValueTracking VT(*MF);
-
-  // Scan MBBs to find the instruction objects by opcode.
-  MachineInstr *OrMI = nullptr;
-  MachineInstr *XorMI = nullptr;
-  MachineInstr *ZextMI = nullptr;
-  MachineInstr *SextMI = nullptr;
-  for (auto &MBB : *MF) {
-    for (auto &MI : MBB) {
-      switch (MI.getOpcode()) {
-      case TargetOpcode::G_OR:
-        OrMI = &MI;
-        break;
-      case TargetOpcode::G_XOR:
-        XorMI = &MI;
-        break;
-      case TargetOpcode::G_ZEXT:
-        ZextMI = &MI;
-        break;
-      case TargetOpcode::G_SEXT:
-        SextMI = &MI;
-        break;
-      default:
-        break;
-      }
-    }
-  }
-  ASSERT_NE(OrMI, nullptr);
-  ASSERT_NE(XorMI, nullptr);
-  ASSERT_NE(ZextMI, nullptr);
-  ASSERT_NE(SextMI, nullptr);
-
-  // G_OR %x8, 0x0F -- full s32 demand -> operand demand = ~0x0F = 0xFFFFFFF0.
-  // Bits where C=1 are always 1 regardless of X, so X is not demanded there.
-  EXPECT_EQ(VT.getDemandedBitsForUse(*OrMI, /*OpIdx=*/1, APInt::getAllOnes(32)),
-            APInt(32, 0xFFFFFFF0U));
-
-  // G_XOR %x8, %c1 -- def demand 0x0F -> each operand demand 0x0F (bit-local).
-  EXPECT_EQ(VT.getDemandedBitsForUse(*XorMI, /*OpIdx=*/1, APInt(32, 0x0F)),
-            APInt(32, 0x0F));
-
-  // G_ZEXT s8->s32, def demand 0x1FF (9 bits set) -> operand demand trunc to
-  // s8 = 0xFF (high def bits are always 0, not from operand).
-  EXPECT_EQ(VT.getDemandedBitsForUse(*ZextMI, /*OpIdx=*/1, APInt(32, 0x1FFU)),
-            APInt(8, 0xFF));
-
-  // G_SEXT s8->s32, def demand 0x80000000 (only sign bit of s32) -> operand
-  // demand must include the sign bit of s8 (bit 7 = 0x80).
-  EXPECT_EQ(
-      VT.getDemandedBitsForUse(*SextMI, /*OpIdx=*/1, APInt(32, 0x80000000U)),
-      APInt(8, 0x80));
-}
-
-// Part C: exhaustive soundness oracle for G_OR-with-constant and G_XOR.
-// For all demand masks D and all operand values x, verifies that
-// getDemandedBitsForUse over-approximates: flipping only outside-demand bits
-// cannot change demanded result bits.
-TEST_F(AArch64GISelMITest, DemandedBitsForUseOrXorSoundnessOracle) {
-  // Width=8, C=0x3C (OR constant), Y=0x5A (XOR second operand).
-  constexpr unsigned W = 8;
-  constexpr uint8_t C = 0x3C;
-  constexpr uint8_t Y = 0x5A;
-
-  StringRef MIRString = R"(
-    %base:_(s8) = G_TRUNC %0
-    %cor:_(s8) = G_CONSTANT i8 60
-    %orr:_(s8) = G_OR %base, %cor
-    %cxor:_(s8) = G_CONSTANT i8 90
-    %xrr:_(s8) = G_XOR %base, %cxor
-    %sentinel:_(s8) = COPY %xrr
-)";
-  setUp(MIRString);
-  if (!TM)
-    GTEST_SKIP();
-  GISelValueTracking VT(*MF);
-
-  MachineInstr *OrMI = nullptr;
-  MachineInstr *XorMI = nullptr;
-  for (auto &MBB : *MF)
-    for (auto &MI : MBB) {
-      if (MI.getOpcode() == TargetOpcode::G_OR)
-        OrMI = &MI;
-      else if (MI.getOpcode() == TargetOpcode::G_XOR)
-        XorMI = &MI;
-    }
-  ASSERT_NE(OrMI, nullptr);
-  ASSERT_NE(XorMI, nullptr);
-
-  for (unsigned D = 0; D < (1u << W); ++D) {
-    APInt Demand(W, D);
-
-    // --- G_OR oracle ---
-    APInt OrDemand = VT.getDemandedBitsForUse(*OrMI, /*OpIdx=*/1, Demand);
-    APInt OutsideOr = ~OrDemand & APInt::getAllOnes(W);
-    for (unsigned x = 0; x < (1u << W); ++x) {
-      uint8_t x2 = static_cast<uint8_t>(x ^ OutsideOr.getZExtValue());
-      uint8_t r1 = static_cast<uint8_t>(x | C);
-      uint8_t r2 = static_cast<uint8_t>(x2 | C);
-      uint8_t diff_in_demanded =
-          static_cast<uint8_t>((r1 ^ r2) & static_cast<uint8_t>(D));
-      EXPECT_EQ(diff_in_demanded, 0)
-          << "G_OR soundness FAILED: D=" << D << " x=" << x
-          << " x2=" << static_cast<unsigned>(x2)
-          << " r1=" << static_cast<unsigned>(r1)
-          << " r2=" << static_cast<unsigned>(r2);
-      if (diff_in_demanded != 0)
-        return; // Stop on first failure as instructed.
-    }
-
-    // --- G_XOR oracle ---
-    APInt XorDemand = VT.getDemandedBitsForUse(*XorMI, /*OpIdx=*/1, Demand);
-    APInt OutsideXor = ~XorDemand & APInt::getAllOnes(W);
-    for (unsigned x = 0; x < (1u << W); ++x) {
-      uint8_t x2 = static_cast<uint8_t>(x ^ OutsideXor.getZExtValue());
-      uint8_t r1 = static_cast<uint8_t>(x ^ Y);
-      uint8_t r2 = static_cast<uint8_t>(x2 ^ Y);
-      uint8_t diff_in_demanded =
-          static_cast<uint8_t>((r1 ^ r2) & static_cast<uint8_t>(D));
-      EXPECT_EQ(diff_in_demanded, 0)
-          << "G_XOR soundness FAILED: D=" << D << " x=" << x
-          << " x2=" << static_cast<unsigned>(x2)
-          << " r1=" << static_cast<unsigned>(r1)
-          << " r2=" << static_cast<unsigned>(r2);
-      if (diff_in_demanded != 0)
-        return;
-    }
-  }
-}

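The rewritten G_OR tests in patch 3/4 exercise the constant-operand case of demanded-bits simplification:

- Under demand `0xFF`, `x | 0xFF00` collapses to `x`.
- Under demand `0x0F`, `x | 0xFF` collapses to the constant.

Below is a minimal sketch of that reasoning. It assumes plain 32-bit integers instead of MachineInstrs. `demandedBitsOfOrOperand`, `simplifyOrUnderMask`, `OrOperand` and `orDemandIsSound` are hypothetical names, not the GlobalISel API (`getDemandedBitsForUse`, `getDemandedBitsSimplifiedReg`) added by this series. The sketch also ignores any known bits of the non-constant operand.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model (not GlobalISel): which operand of Def = X | C may
// replace Def for a user that only observes the Demanded bits.
enum class OrOperand { X, Const };

// Reverse transfer for Def = X | C with constant C: a bit of X is demanded
// only where C is 0, because C forces the result to 1 where C is 1.
uint32_t demandedBitsOfOrOperand(uint32_t DemandedDef, uint32_t C) {
  return DemandedDef & ~C;
}

// Decide whether X | C can be replaced by one of its operands when only the
// bits in Demanded are observed (e.g. the user is G_AND Def, Demanded).
std::optional<OrOperand> simplifyOrUnderMask(uint32_t C, uint32_t Demanded) {
  if ((C & Demanded) == 0)
    return OrOperand::X;     // C sets no demanded bit: (X | C) & D == X & D.
  if ((Demanded & ~C) == 0)
    return OrOperand::Const; // C already sets every demanded bit to 1.
  return std::nullopt;       // Both operands contribute demanded bits.
}

// Exhaustive 8-bit check in the spirit of the removed oracle test: flipping
// bits of X outside its computed demand never changes a demanded bit of X | C.
bool orDemandIsSound(uint8_t C) {
  for (uint32_t D = 0; D < 256; ++D) {
    uint32_t Outside = ~demandedBitsOfOrOperand(D, C) & 0xFFu;
    for (uint32_t X = 0; X < 256; ++X) {
      uint32_t X2 = X ^ Outside;
      if ((((X | C) ^ (X2 | C)) & D) != 0)
        return false;
    }
  }
  return true;
}
```

The multi-use variant of the test shows an important limit. The G_OR survives there because `%side` still reads it, so only the G_AND operand is rewired. The decision is made per use; it is not a rewrite of the producer.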
>From daf039eb73e355d10e34ac033ad3598000d67c8c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 11 Jun 2026 14:50:41 -0700
Subject: [PATCH 4/4] [flang] Relax offload image suffix check in
 omp-driver-offload.f90

The driver now emits a .s file for the AMDGPU offload image in this
path, while the test still expected .bc. Drop the suffix from the
FileCheck pattern; the test's intent is to verify that
llvm-offload-binary is invoked with the correct triple/arch/kind,
not the image file type.
---
 flang/test/Driver/omp-driver-offload.f90 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/test/Driver/omp-driver-offload.f90 b/flang/test/Driver/omp-driver-offload.f90
index a4184f11bf3b0..217ced988ceac 100644
--- a/flang/test/Driver/omp-driver-offload.f90
+++ b/flang/test/Driver/omp-driver-offload.f90
@@ -61,7 +61,7 @@
 ! OPENMP-OFFLOAD-ARGS-SAME:  "-fopenmp"
 ! OPENMP-OFFLOAD-ARGS-SAME:  "-fopenmp-host-ir-file-path" "{{.*}}.bc" "-fopenmp-is-target-device"
 ! OPENMP-OFFLOAD-ARGS-SAME:  {{.*}}.f90"
-! OPENMP-OFFLOAD-ARGS: "{{[^"]*}}llvm-offload-binary{{.*}}" {{.*}} "--image=file={{.*}}.bc,triple=amdgcn-amd-amdhsa,arch=gfx90a,kind=openmp"
+! OPENMP-OFFLOAD-ARGS: "{{[^"]*}}llvm-offload-binary{{.*}}" {{.*}} "--image=file={{.*}},triple=amdgcn-amd-amdhsa,arch=gfx90a,kind=openmp"
 ! OPENMP-OFFLOAD-ARGS-NEXT: "{{[^"]*}}flang" "-fc1" "-triple" "aarch64-unknown-linux-gnu"
 ! OPENMP-OFFLOAD-ARGS-SAME:  "-fopenmp"
 ! OPENMP-OFFLOAD-ARGS-SAME:  "-fembed-offload-object={{.*}}.out" {{.*}}.bc"

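Patch 4/4 drops `.bc` from the `--image=file=` pattern, and the reason follows from FileCheck's pattern syntax. Text inside `{{...}}` is a regex, while the rest of the pattern is matched literally. The old pattern `{{.*}}.bc,` therefore requires a literal `.bc,` and cannot match a `.s` image path, whereas `{{.*}},` matches either suffix.

Below is a rough sketch of that behaviour. `fileCheckToRegex` and `checkMatches` are hypothetical helpers. They use `std::regex` (ECMAScript) rather than FileCheck's own regex engine, and they ignore `[[...]]` variables and whitespace canonicalization.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Hypothetical helper approximating FileCheck's pattern syntax: text inside
// {{...}} is a regex, everything else is matched literally. It ignores
// [[...]] variables and whitespace canonicalization.
std::string fileCheckToRegex(const std::string &Pattern) {
  static const std::string Meta = "\\^$.|?*+()[]{}";
  std::string Out;
  size_t Pos = 0;
  while (Pos < Pattern.size()) {
    size_t Open = Pattern.find("{{", Pos);
    size_t LitEnd = Open == std::string::npos ? Pattern.size() : Open;
    // Escape regex metacharacters in the literal part.
    for (size_t I = Pos; I < LitEnd; ++I) {
      if (Meta.find(Pattern[I]) != std::string::npos)
        Out += '\\';
      Out += Pattern[I];
    }
    if (Open == std::string::npos)
      break;
    size_t Close = Pattern.find("}}", Open + 2);
    if (Close == std::string::npos)
      break; // Malformed pattern: drop the unterminated {{ part.
    Out += "(" + Pattern.substr(Open + 2, Close - Open - 2) + ")";
    Pos = Close + 2;
  }
  return Out;
}

// FileCheck looks for the pattern anywhere in the line, hence regex_search.
bool checkMatches(const std::string &Pattern, const std::string &Line) {
  return std::regex_search(Line, std::regex(fileCheckToRegex(Pattern)));
}
```

With an illustrative (made-up) `llvm-offload-binary` command line whose image is a `.s` file, the old pattern fails and the relaxed one matches. A `.bc` image still matches both patterns.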