[llvm] 9f2a068 - [DAG] Add getValid*ShiftAmountConstant wrappers without DemandedElts

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 05:27:01 PDT 2024


Author: Simon Pilgrim
Date: 2024-04-24T13:26:43+01:00
New Revision: 9f2a068bffad4a36db088673210f680bfd08b3d1

URL: https://github.com/llvm/llvm-project/commit/9f2a068bffad4a36db088673210f680bfd08b3d1
DIFF: https://github.com/llvm/llvm-project/commit/9f2a068bffad4a36db088673210f680bfd08b3d1.diff

LOG: [DAG] Add getValid*ShiftAmountConstant wrappers without DemandedElts

Simplify callers which don't have their own DemandedElts mask.

Noticed while reviewing #88801

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 95dbe74327cfc3..f353aef1f446ff 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2146,18 +2146,32 @@ class SelectionDAG {
   const APInt *getValidShiftAmountConstant(SDValue V,
                                            const APInt &DemandedElts) const;
 
+  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// below the element bit-width in every lane, return it; otherwise nullptr.
+  const APInt *getValidShiftAmountConstant(SDValue V) const;
+
   /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
   /// than the element bit-width of the shift node, return the minimum value.
   const APInt *
   getValidMinimumShiftAmountConstant(SDValue V,
                                      const APInt &DemandedElts) const;
 
+  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
+  /// than the element bit-width, return the minimum over all lanes or nullptr.
+  const APInt *
+  getValidMinimumShiftAmountConstant(SDValue V) const;
+
   /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
   /// than the element bit-width of the shift node, return the maximum value.
   const APInt *
   getValidMaximumShiftAmountConstant(SDValue V,
                                      const APInt &DemandedElts) const;
 
+  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
+  /// than the element bit-width, return the maximum over all lanes or nullptr.
+  const APInt *
+  getValidMaximumShiftAmountConstant(SDValue V) const;
+
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
   /// Extract. The reduction must use one of the opcodes listed in /p

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 78395b3d249f5f..23ebfe466c7471 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2993,6 +2993,14 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
   return nullptr;
 }
 
+const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+  // Demand all lanes; scalars and scalable vectors use a one-bit mask.
+  EVT ShiftVT = V.getValueType();
+  unsigned NumElts =
+      ShiftVT.isFixedLengthVector() ? ShiftVT.getVectorNumElements() : 1;
+  return getValidShiftAmountConstant(V, APInt::getAllOnes(NumElts));
+}
+
 const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
     SDValue V, const APInt &DemandedElts) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3022,6 +3030,14 @@ const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
   return MinShAmt;
 }
 
+const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+  // Demand all lanes; scalars and scalable vectors use a one-bit mask.
+  EVT ShiftVT = V.getValueType();
+  unsigned NumElts =
+      ShiftVT.isFixedLengthVector() ? ShiftVT.getVectorNumElements() : 1;
+  return getValidMinimumShiftAmountConstant(V, APInt::getAllOnes(NumElts));
+}
+
 const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
     SDValue V, const APInt &DemandedElts) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3051,6 +3067,14 @@ const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
   return MaxShAmt;
 }
 
+const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+  // Demand all lanes; scalars and scalable vectors use a one-bit mask.
+  EVT ShiftVT = V.getValueType();
+  unsigned NumElts =
+      ShiftVT.isFixedLengthVector() ? ShiftVT.getVectorNumElements() : 1;
+  return getValidMaximumShiftAmountConstant(V, APInt::getAllOnes(NumElts));
+}
+
 /// Determine which bits of Op are known to be either zero or one and return
 /// them in Known. For vectors, the known bits are those that are shared by
 /// every vector element.

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a4df05e1bd03ca..0bb737b04b7ee7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20459,8 +20459,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(
-            In, APInt::getAllOnes(SrcVT.getVectorNumElements()))) {
+    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());


        


More information about the llvm-commits mailing list