[llvm] bc0fea0 - [SDAG] Allow scalable vectors in ComputeKnownBits
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 18 07:40:52 PST 2022
Author: Philip Reames
Date: 2022-11-18T07:40:32-08:00
New Revision: bc0fea0d551b5182c541c43070830bfdcaa33ef2
URL: https://github.com/llvm/llvm-project/commit/bc0fea0d551b5182c541c43070830bfdcaa33ef2
DIFF: https://github.com/llvm/llvm-project/commit/bc0fea0d551b5182c541c43070830bfdcaa33ef2.diff
LOG: [SDAG] Allow scalable vectors in ComputeKnownBits
This is the SelectionDAG equivalent of D136470, and is thus an alternative patch to D128159.
The basic idea here is that we track a single lane for scalable vectors, which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations.
This patch also includes an implementation for SPLAT_VECTOR since, without it, the lane-wise reasoning has no base case. The original patch which inspired this (D128159) also included STEP_VECTOR; I plan to do that as a separate patch.
Differential Revision: https://reviews.llvm.org/D137140
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3d38968a7bc3..ce86bf4ea25b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2910,14 +2910,10 @@ const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
EVT VT = Op.getValueType();
- // TOOD: Until we have a plan for how to represent demanded elements for
- // scalable vectors, we can just bail out for now.
- if (Op.getValueType().isScalableVector()) {
- unsigned BitWidth = Op.getScalarValueSizeInBits();
- return KnownBits(BitWidth);
- }
-
- APInt DemandedElts = VT.isVector()
+ // Since the number of lanes in a scalable vector is unknown at compile time,
+ // we track one bit which is implicitly broadcast to all lanes. This means
+ // that all lanes in a scalable vector are considered demanded.
+ APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
return computeKnownBits(Op, DemandedElts, Depth);
@@ -2932,11 +2928,6 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
KnownBits Known(BitWidth); // Don't know anything.
- // TOOD: Until we have a plan for how to represent demanded elements for
- // scalable vectors, we can just bail out for now.
- if (Op.getValueType().isScalableVector())
- return Known;
-
if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
// We know all of the bits for a constant!
return KnownBits::makeConstant(C->getAPIntValue());
@@ -2951,7 +2942,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
KnownBits Known2;
unsigned NumElts = DemandedElts.getBitWidth();
- assert((!Op.getValueType().isVector() ||
+ assert((!Op.getValueType().isFixedLengthVector() ||
NumElts == Op.getValueType().getVectorNumElements()) &&
"Unexpected vector size");
@@ -2963,7 +2954,18 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::MERGE_VALUES:
return computeKnownBits(Op.getOperand(Op.getResNo()), DemandedElts,
Depth + 1);
+ case ISD::SPLAT_VECTOR: {
+ SDValue SrcOp = Op.getOperand(0);
+ Known = computeKnownBits(SrcOp, Depth + 1);
+ if (SrcOp.getValueSizeInBits() != BitWidth) {
+ assert(SrcOp.getValueSizeInBits() > BitWidth &&
+ "Expected SPLAT_VECTOR implicit truncation");
+ Known = Known.trunc(BitWidth);
+ }
+ break;
+ }
case ISD::BUILD_VECTOR:
+ assert(!Op.getValueType().isScalableVector());
// Collect the known bits that are shared by every demanded vector element.
Known.Zero.setAllBits(); Known.One.setAllBits();
for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
@@ -2989,6 +2991,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
}
break;
case ISD::VECTOR_SHUFFLE: {
+ assert(!Op.getValueType().isScalableVector());
// Collect the known bits that are shared by every vector element referenced
// by the shuffle.
APInt DemandedLHS, DemandedRHS;
@@ -3016,6 +3019,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::CONCAT_VECTORS: {
+ if (Op.getValueType().isScalableVector())
+ break;
// Split DemandedElts and test each of the demanded subvectors.
Known.Zero.setAllBits(); Known.One.setAllBits();
EVT SubVectorVT = Op.getOperand(0).getValueType();
@@ -3036,6 +3041,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::INSERT_SUBVECTOR: {
+ if (Op.getValueType().isScalableVector())
+ break;
// Demand any elements from the subvector and the remainder from the src its
// inserted into.
SDValue Src = Op.getOperand(0);
@@ -3063,7 +3070,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
// Offset the demanded elts by the subvector index.
SDValue Src = Op.getOperand(0);
// Bail until we can represent demanded elements for scalable vectors.
- if (Src.getValueType().isScalableVector())
+ if (Op.getValueType().isScalableVector() || Src.getValueType().isScalableVector())
break;
uint64_t Idx = Op.getConstantOperandVal(1);
unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
@@ -3072,6 +3079,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::SCALAR_TO_VECTOR: {
+ if (Op.getValueType().isScalableVector())
+ break;
// We know about scalar_to_vector as much as we know about it source,
// which becomes the first element of otherwise unknown vector.
if (DemandedElts != 1)
@@ -3085,6 +3094,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::BITCAST: {
+ if (Op.getValueType().isScalableVector())
+ break;
+
SDValue N0 = Op.getOperand(0);
EVT SubVT = N0.getValueType();
unsigned SubBitWidth = SubVT.getScalarSizeInBits();
@@ -3406,7 +3418,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
if (ISD::isNON_EXTLoad(LD) && Cst) {
// Determine any common known bits from the loaded constant pool value.
Type *CstTy = Cst->getType();
- if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits()) {
+ if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits() &&
+ !Op.getValueType().isScalableVector()) {
// If its a vector splat, then we can (quickly) reuse the scalar path.
// NOTE: We assume all elements match and none are UNDEF.
if (CstTy->isVectorTy()) {
@@ -3480,6 +3493,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::ZERO_EXTEND_VECTOR_INREG: {
+ if (Op.getValueType().isScalableVector())
+ break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3492,6 +3507,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::SIGN_EXTEND_VECTOR_INREG: {
+ if (Op.getValueType().isScalableVector())
+ break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3508,6 +3525,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::ANY_EXTEND_VECTOR_INREG: {
+ if (Op.getValueType().isScalableVector())
+ break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3673,6 +3692,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::INSERT_VECTOR_ELT: {
+ if (Op.getValueType().isScalableVector())
+ break;
+
// If we know the element index, split the demand between the
// source vector and the inserted element, otherwise assume we need
// the original demanded vector elements and the value.
@@ -3839,6 +3861,11 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::INTRINSIC_WO_CHAIN:
case ISD::INTRINSIC_W_CHAIN:
case ISD::INTRINSIC_VOID:
+ // TODO: Probably okay to remove after audit; here to reduce change size
+ // in initial enablement patch for scalable vectors
+ if (Op.getValueType().isScalableVector())
+ break;
+
// Allow the target to implement this method for its nodes.
TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth);
break;
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
index cc6000954763..7c46cf9c239e 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
@@ -55,7 +55,8 @@ define <vscale x 2 x i64> @index_ii_range() {
define <vscale x 8 x i16> @index_ii_range_combine(i16 %a) {
; CHECK-LABEL: index_ii_range_combine:
; CHECK: // %bb.0:
-; CHECK-NEXT: index z0.h, #2, #8
+; CHECK-NEXT: index z0.h, #0, #8
+; CHECK-NEXT: orr z0.h, z0.h, #0x2
; CHECK-NEXT: ret
%val = insertelement <vscale x 8 x i16> poison, i16 2, i32 0
%val1 = shufflevector <vscale x 8 x i16> %val, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
index 7b31f0e7f6d4..2f3f342cdb93 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -574,7 +574,7 @@ define <vscale x 2 x i64> @dupq_i64_range(<vscale x 2 x i64> %a) {
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.d, #0, #1
; CHECK-NEXT: and z1.d, z1.d, #0x1
-; CHECK-NEXT: add z1.d, z1.d, #8 // =0x8
+; CHECK-NEXT: orr z1.d, z1.d, #0x8
; CHECK-NEXT: tbl z0.d, { z0.d }, z1.d
; CHECK-NEXT: ret
%out = call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %a, i64 4)
diff --git a/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll b/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
index 377031ba6b20..8ef7b8032cc0 100644
--- a/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
+++ b/llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @umulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xff
; CHECK-NEXT: and z0.d, z0.d, #0xff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: lsr z1.d, z2.d, #8
-; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: lsr z1.d, z0.d, #8
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
%b = extractvalue { <vscale x 2 x i8>, <vscale x 2 x i1> } %a, 0
@@ -34,15 +29,10 @@ define <vscale x 4 x i8> @umulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: and z1.s, z1.s, #0xff
; CHECK-NEXT: and z0.s, z0.s, #0xff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT: lsr z1.s, z2.s, #8
-; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
+; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: lsr z1.s, z0.s, #8
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
%b = extractvalue { <vscale x 4 x i8>, <vscale x 4 x i1> } %a, 0
@@ -59,15 +49,10 @@ define <vscale x 8 x i8> @umulo_nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z0.h, z0.h, #0xff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.h, p0/m, z2.h, z1.h
-; CHECK-NEXT: umulh z0.h, p0/m, z0.h, z1.h
-; CHECK-NEXT: lsr z1.h, z2.h, #8
-; CHECK-NEXT: cmpne p1.h, p0/z, z0.h, #0
+; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: lsr z1.h, z0.h, #8
; CHECK-NEXT: cmpne p0.h, p0/z, z1.h, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.h, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.h, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.umul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
%b = extractvalue { <vscale x 8 x i8>, <vscale x 8 x i1> } %a, 0
@@ -164,15 +149,10 @@ define <vscale x 2 x i16> @umulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xffff
; CHECK-NEXT: and z0.d, z0.d, #0xffff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: lsr z1.d, z2.d, #16
-; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: lsr z1.d, z0.d, #16
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
%b = extractvalue { <vscale x 2 x i16>, <vscale x 2 x i1> } %a, 0
@@ -189,15 +169,10 @@ define <vscale x 4 x i16> @umulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z0.s, z0.s, #0xffff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT: lsr z1.s, z2.s, #16
-; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
+; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: lsr z1.s, z0.s, #16
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
%b = extractvalue { <vscale x 4 x i16>, <vscale x 4 x i1> } %a, 0
@@ -294,15 +269,10 @@ define <vscale x 2 x i32> @umulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xffffffff
; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: lsr z1.d, z2.d, #32
-; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
+; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: lsr z1.d, z0.d, #32
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
-; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
%b = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i1> } %a, 0
More information about the llvm-commits
mailing list