[llvm] 186c192 - [SDAG] Allow scalable vectors in SimplifyDemanded routines

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 5 12:42:31 PST 2022


Author: Philip Reames
Date: 2022-12-05T12:42:16-08:00
New Revision: 186c1922611501b701128987a5c938287d048cd7

URL: https://github.com/llvm/llvm-project/commit/186c1922611501b701128987a5c938287d048cd7
DIFF: https://github.com/llvm/llvm-project/commit/186c1922611501b701128987a5c938287d048cd7.diff

LOG: [SDAG] Allow scalable vectors in SimplifyDemanded routines

This is a continuation of the series of patches adding lane-wise support for scalable vectors in various KnownBits-esque routines.

The basic idea here is that we track a single lane for scalable vectors which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations.

Differential Revision: https://reviews.llvm.org/D137190

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/test/CodeGen/AArch64/active_lane_mask.ll
    llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 466a2edbd3efd..553facc40d233 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -634,16 +634,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
-  if (VT.isScalableVector()) {
-    // Pretend we don't know anything for now.
-    Known = KnownBits(DemandedBits.getBitWidth());
-    return false;
-  }
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
@@ -656,12 +650,6 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
   // Limit search depth.
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return SDValue();
@@ -680,6 +668,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   KnownBits LHSKnown, RHSKnown;
   switch (Op.getOpcode()) {
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Src = peekThroughBitcasts(Op.getOperand(0));
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
@@ -825,6 +816,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::ANY_EXTEND_VECTOR_INREG:
   case ISD::SIGN_EXTEND_VECTOR_INREG:
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we only want the lowest element and none of extended bits, then we can
     // return the bitcasted source vector.
     SDValue Src = Op.getOperand(0);
@@ -838,6 +832,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we don't demand the inserted element, return the base vector.
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -848,6 +845,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Vec = Op.getOperand(0);
     SDValue Sub = Op.getOperand(1);
     uint64_t Idx = Op.getConstantOperandVal(2);
@@ -868,6 +868,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // If all the demanded elts are from one operand and are inline,
@@ -891,6 +892,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   default:
+    // TODO: Probably okay to remove after audit; here to reduce change size
+    // in initial enablement patch for scalable vectors
+    if (VT.isScalableVector())
+      return SDValue();
+
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
               Op, DemandedBits, DemandedElts, DAG, Depth))
@@ -904,14 +910,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
   EVT VT = Op.getValueType();
-
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
@@ -1070,16 +1072,10 @@ bool TargetLowering::SimplifyDemandedBits(
   // Don't know anything.
   Known = KnownBits(BitWidth);
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return false;
-
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
   unsigned NumElts = OriginalDemandedElts.getBitWidth();
-  assert((!VT.isVector() || NumElts == VT.getVectorNumElements()) &&
+  assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
          "Unexpected vector size");
 
   APInt DemandedBits = OriginalDemandedBits;
@@ -1130,6 +1126,8 @@ bool TargetLowering::SimplifyDemandedBits(
   KnownBits Known2;
   switch (Op.getOpcode()) {
   case ISD::SCALAR_TO_VECTOR: {
+    if (VT.isScalableVector())
+      return false;
     if (!DemandedElts[0])
       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
 
@@ -1167,6 +1165,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Vec = Op.getOperand(0);
     SDValue Scl = Op.getOperand(1);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -1203,6 +1203,8 @@ bool TargetLowering::SimplifyDemandedBits(
     return false;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -1246,6 +1248,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::EXTRACT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
     // Offset the demanded elts by the subvector index.
     SDValue Src = Op.getOperand(0);
     if (Src.getValueType().isScalableVector())
@@ -1271,6 +1275,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return false;
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     EVT SubVT = Op.getOperand(0).getValueType();
@@ -1289,6 +1295,7 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // Collect demanded elements from shuffle operands..
@@ -1366,7 +1373,7 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
     // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
-    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR &&
+    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
         (Op0.getOperand(0).isUndef() ||
          ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
         Op0->hasOneUse()) {
@@ -2226,12 +2233,15 @@ bool TargetLowering::SimplifyDemandedBits(
     Known = KnownHi.concat(KnownLo);
     break;
   }
-  case ISD::ZERO_EXTEND:
-  case ISD::ZERO_EXTEND_VECTOR_INREG: {
+  case ISD::ZERO_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ZERO_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2263,12 +2273,15 @@ bool TargetLowering::SimplifyDemandedBits(
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::SIGN_EXTEND:
-  case ISD::SIGN_EXTEND_VECTOR_INREG: {
+  case ISD::SIGN_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::SIGN_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2315,12 +2328,15 @@ bool TargetLowering::SimplifyDemandedBits(
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::ANY_EXTEND:
-  case ISD::ANY_EXTEND_VECTOR_INREG: {
+  case ISD::ANY_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ANY_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
 
     // If we only need the bottom element then we can just bitcast.
@@ -2459,6 +2475,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
@@ -2680,6 +2698,10 @@ bool TargetLowering::SimplifyDemandedBits(
     // We also ask the target about intrinsics (which could be specific to it).
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
         Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+      // TODO: Probably okay to remove after audit; here to reduce change size
+      // in initial enablement patch for scalable vectors
+      if (Op.getValueType().isScalableVector())
+        break;
       if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
                                             Known, TLO, Depth))
         return true;
@@ -2749,7 +2771,7 @@ static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
          "Vector binop only");
 
   EVT EltVT = VT.getVectorElementType();
-  unsigned NumElts = VT.getVectorNumElements();
+  unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
   assert(UndefOp0.getBitWidth() == NumElts &&
          UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
 

diff  --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
index 211361da18010..cb2b498f274b0 100644
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -113,14 +113,13 @@ define <vscale x 4 x i1> @lane_mask_nxv4i1_i8(i8 %index, i8 %TC) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    and w8, w0, #0xff
 ; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    and w9, w1, #0xff
 ; CHECK-NEXT:    and z0.s, z0.s, #0xff
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    mov z1.s, w8
-; CHECK-NEXT:    and w8, w1, #0xff
 ; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, w9
 ; CHECK-NEXT:    umin z0.s, z0.s, #255
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    mov z1.s, w8
 ; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i8(i8 %index, i8 %TC)
@@ -132,17 +131,16 @@ define <vscale x 2 x i1> @lane_mask_nxv2i1_i8(i8 %index, i8 %TC) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
 ; CHECK-NEXT:    and x8, x0, #0xff
-; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
 ; CHECK-NEXT:    and x9, x1, #0xff
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
+; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z0.d, z0.d, #0xff
 ; CHECK-NEXT:    mov z1.d, x8
+; CHECK-NEXT:    mov z2.d, x9
 ; CHECK-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEXT:    mov z1.d, x9
 ; CHECK-NEXT:    umin z0.d, z0.d, #255
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
-; CHECK-NEXT:    cmphi p0.d, p0/z, z1.d, z0.d
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z0.d
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i8(i8 %index, i8 %TC)
   ret <vscale x 2 x i1> %active.lane.mask

diff  --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index 67ea38b06ef57..518e714c0cd3c 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -224,11 +224,15 @@ TEST_F(AArch64SelectionDAGTest, SimplifyDemandedBitsSVE) {
 
   SDValue Op = DAG->getNode(ISD::AND, Loc, InVecVT, N0, Mask2V);
 
+  // N0 = ?000?0?0
+  // Mask2V = 01010101
+  //  =>
+  // Known.Zero = 10101010 (0xAA)
   KnownBits Known;
   APInt DemandedBits = APInt(8, 0xFF);
   TargetLowering::TargetLoweringOpt TLO(*DAG, false, false);
-  EXPECT_FALSE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
-  EXPECT_EQ(Known.Zero, APInt(8, 0));
+  EXPECT_TRUE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
+  EXPECT_EQ(Known.Zero, APInt(8, 0xAA));
 }
 
 // Piggy-backing on the AArch64 tests to verify SelectionDAG::computeKnownBits.


        


More information about the llvm-commits mailing list