[llvm] 7969ab8 - [SDAG] Allow scalable vectors in ComputeKnownBits (try 2)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 5 08:53:26 PST 2022


Author: Philip Reames
Date: 2022-12-05T08:52:37-08:00
New Revision: 7969ab85e0a413eeba2bf4360f0c7ea6c5f00e4c

URL: https://github.com/llvm/llvm-project/commit/7969ab85e0a413eeba2bf4360f0c7ea6c5f00e4c
DIFF: https://github.com/llvm/llvm-project/commit/7969ab85e0a413eeba2bf4360f0c7ea6c5f00e4c.diff

LOG: [SDAG] Allow scalable vectors in ComputeKnownBits (try 2)

This was previously reverted due to a hang on a Hexagon bot.  This turned out to be a bug in the Hexagon backend around how splat_vectors are legalized (which they're using for fixed-length vectors!).  I adjusted this patch to remove the implicit truncate support.  This hides the Hexagon bug for now, and unblocks the rest of the change.

Original commit message:

This is the SelectionDAG equivalent of D136470, and is thus an alternate patch to D128159.

The basic idea here is that we track a single lane for scalable vectors, which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations.

This patch also includes an implementation for SPLAT_VECTOR, as without it the lane-wise reasoning has no base case. The original patch which inspired this (D128159) also included STEP_VECTOR. I plan to do that as a separate patch.

Differential Revision: https://reviews.llvm.org/D137140

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
    llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
    llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 76b09961b053..37809f8db76c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2910,14 +2910,10 @@ const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
 KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // TOOD: Until we have a plan for how to represent demanded elements for
-  // scalable vectors, we can just bail out for now.
-  if (Op.getValueType().isScalableVector()) {
-    unsigned BitWidth = Op.getScalarValueSizeInBits();
-    return KnownBits(BitWidth);
-  }
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return computeKnownBits(Op, DemandedElts, Depth);
@@ -2932,11 +2928,6 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
 
   KnownBits Known(BitWidth);   // Don't know anything.
 
-  // TOOD: Until we have a plan for how to represent demanded elements for
-  // scalable vectors, we can just bail out for now.
-  if (Op.getValueType().isScalableVector())
-    return Known;
-
   if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
     // We know all of the bits for a constant!
     return KnownBits::makeConstant(C->getAPIntValue());
@@ -2951,7 +2942,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
 
   KnownBits Known2;
   unsigned NumElts = DemandedElts.getBitWidth();
-  assert((!Op.getValueType().isVector() ||
+  assert((!Op.getValueType().isFixedLengthVector() ||
           NumElts == Op.getValueType().getVectorNumElements()) &&
          "Unexpected vector size");
 
@@ -2963,7 +2954,23 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::MERGE_VALUES:
     return computeKnownBits(Op.getOperand(Op.getResNo()), DemandedElts,
                             Depth + 1);
+  case ISD::SPLAT_VECTOR: {
+    SDValue SrcOp = Op.getOperand(0);
+    if (SrcOp.getValueSizeInBits() != BitWidth) {
+      assert(SrcOp.getValueSizeInBits() > BitWidth &&
+             "Expected SPLAT_VECTOR implicit truncation");
+      // FIXME: We should be able to truncate the known bits here to match
+      // the official semantics of SPLAT_VECTOR, but doing so exposes a
+      // Hexagon target bug which results in an infinite loop during
+      // DAGCombine.  (See D137140 for repo).  Once that's fixed, we can
+      // strengthen this.
+      break;
+    }
+    Known = computeKnownBits(SrcOp, Depth + 1);
+    break;
+  }
   case ISD::BUILD_VECTOR:
+    assert(!Op.getValueType().isScalableVector());
     // Collect the known bits that are shared by every demanded vector element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
     for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
@@ -2989,6 +2996,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     }
     break;
   case ISD::VECTOR_SHUFFLE: {
+    assert(!Op.getValueType().isScalableVector());
     // Collect the known bits that are shared by every vector element referenced
     // by the shuffle.
     APInt DemandedLHS, DemandedRHS;
@@ -3016,6 +3024,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::CONCAT_VECTORS: {
+    if (Op.getValueType().isScalableVector())
+      break;
     // Split DemandedElts and test each of the demanded subvectors.
     Known.Zero.setAllBits(); Known.One.setAllBits();
     EVT SubVectorVT = Op.getOperand(0).getValueType();
@@ -3036,6 +3046,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (Op.getValueType().isScalableVector())
+      break;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -3063,7 +3075,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     // Offset the demanded elts by the subvector index.
     SDValue Src = Op.getOperand(0);
     // Bail until we can represent demanded elements for scalable vectors.
-    if (Src.getValueType().isScalableVector())
+    if (Op.getValueType().isScalableVector() || Src.getValueType().isScalableVector())
       break;
     uint64_t Idx = Op.getConstantOperandVal(1);
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
@@ -3072,6 +3084,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::SCALAR_TO_VECTOR: {
+    if (Op.getValueType().isScalableVector())
+      break;
     // We know about scalar_to_vector as much as we know about it source,
     // which becomes the first element of otherwise unknown vector.
     if (DemandedElts != 1)
@@ -3085,6 +3099,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::BITCAST: {
+    if (Op.getValueType().isScalableVector())
+      break;
+
     SDValue N0 = Op.getOperand(0);
     EVT SubVT = N0.getValueType();
     unsigned SubBitWidth = SubVT.getScalarSizeInBits();
@@ -3406,7 +3423,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     if (ISD::isNON_EXTLoad(LD) && Cst) {
       // Determine any common known bits from the loaded constant pool value.
       Type *CstTy = Cst->getType();
-      if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits()) {
+      if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits() &&
+          !Op.getValueType().isScalableVector()) {
         // If its a vector splat, then we can (quickly) reuse the scalar path.
         // NOTE: We assume all elements match and none are UNDEF.
         if (CstTy->isVectorTy()) {
@@ -3480,6 +3498,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
+    if (Op.getValueType().isScalableVector())
+      break;
     EVT InVT = Op.getOperand(0).getValueType();
     APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3492,6 +3512,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
+    if (Op.getValueType().isScalableVector())
+      break;
     EVT InVT = Op.getOperand(0).getValueType();
     APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3508,6 +3530,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::ANY_EXTEND_VECTOR_INREG: {
+    if (Op.getValueType().isScalableVector())
+      break;
     EVT InVT = Op.getOperand(0).getValueType();
     APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3673,6 +3697,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (Op.getValueType().isScalableVector())
+      break;
+
     // If we know the element index, split the demand between the
     // source vector and the inserted element, otherwise assume we need
     // the original demanded vector elements and the value.
@@ -3839,6 +3866,11 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::INTRINSIC_WO_CHAIN:
   case ISD::INTRINSIC_W_CHAIN:
   case ISD::INTRINSIC_VOID:
+    // TODO: Probably okay to remove after audit; here to reduce change size
+    // in initial enablement patch for scalable vectors
+    if (Op.getValueType().isScalableVector())
+      break;
+
     // Allow the target to implement this method for its nodes.
     TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth);
     break;

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
index cc6000954763..7c46cf9c239e 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
@@ -55,7 +55,8 @@ define <vscale x 2 x i64> @index_ii_range() {
 define <vscale x 8 x i16> @index_ii_range_combine(i16 %a) {
 ; CHECK-LABEL: index_ii_range_combine:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    index z0.h, #2, #8
+; CHECK-NEXT:    index z0.h, #0, #8
+; CHECK-NEXT:    orr z0.h, z0.h, #0x2
 ; CHECK-NEXT:    ret
   %val = insertelement <vscale x 8 x i16> poison, i16 2, i32 0
   %val1 = shufflevector <vscale x 8 x i16> %val, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
index 7b31f0e7f6d4..2f3f342cdb93 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -574,7 +574,7 @@ define <vscale x 2 x i64> @dupq_i64_range(<vscale x 2 x i64> %a) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    index z1.d, #0, #1
 ; CHECK-NEXT:    and z1.d, z1.d, #0x1
-; CHECK-NEXT:    add z1.d, z1.d, #8 // =0x8
+; CHECK-NEXT:    orr z1.d, z1.d, #0x8
 ; CHECK-NEXT:    tbl z0.d, { z0.d }, z1.d
 ; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %a, i64 4)

diff  --git a/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll b/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
index 377031ba6b20..8b2455bff61e 100644
--- a/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
+++ b/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @umulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    and z1.d, z1.d, #0xff
 ; CHECK-NEXT:    and z0.d, z0.d, #0xff
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    lsr z1.d, z2.d, #8
-; CHECK-NEXT:    cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    lsr z1.d, z0.d, #8
 ; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT:    mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
   %b = extractvalue { <vscale x 2 x i8>, <vscale x 2 x i1> } %a, 0
@@ -34,15 +29,10 @@ define <vscale x 4 x i8> @umulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    and z1.s, z1.s, #0xff
 ; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    umulh z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    lsr z1.s, z2.s, #8
-; CHECK-NEXT:    cmpne p1.s, p0/z, z0.s, #0
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    lsr z1.s, z0.s, #8
 ; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, #0
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT:    mov z2.s, p0/m, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
   %b = extractvalue { <vscale x 4 x i8>, <vscale x 4 x i1> } %a, 0
@@ -164,15 +154,10 @@ define <vscale x 2 x i16> @umulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    and z1.d, z1.d, #0xffff
 ; CHECK-NEXT:    and z0.d, z0.d, #0xffff
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    lsr z1.d, z2.d, #16
-; CHECK-NEXT:    cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    lsr z1.d, z0.d, #16
 ; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT:    mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
   %b = extractvalue { <vscale x 2 x i16>, <vscale x 2 x i1> } %a, 0
@@ -189,15 +174,10 @@ define <vscale x 4 x i16> @umulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEXT:    and z0.s, z0.s, #0xffff
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    umulh z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    lsr z1.s, z2.s, #16
-; CHECK-NEXT:    cmpne p1.s, p0/z, z0.s, #0
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    lsr z1.s, z0.s, #16
 ; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, #0
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT:    mov z2.s, p0/m, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
   %b = extractvalue { <vscale x 4 x i16>, <vscale x 4 x i1> } %a, 0
@@ -294,15 +274,10 @@ define <vscale x 2 x i32> @umulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
 ; CHECK-NEXT:    and z0.d, z0.d, #0xffffffff
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    lsr z1.d, z2.d, #32
-; CHECK-NEXT:    cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    lsr z1.d, z0.d, #32
 ; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT:    mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
   %b = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i1> } %a, 0


        


More information about the llvm-commits mailing list