[llvm] [AArch64] Add @llvm.experimental.vector.match (PR #101974)
Ricardo Jesus via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 13 07:48:14 PST 2024
https://github.com/rj-jesus updated https://github.com/llvm/llvm-project/pull/101974
>From e9bd6d43f01c815839b2cf82a4afddb92cb7b711 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Fri, 19 Jul 2024 16:10:51 +0100
Subject: [PATCH 1/7] [AArch64] Add @llvm.experimental.vector.match
This patch introduces an experimental intrinsic for matching the
elements of one vector against the elements of another.
For AArch64 targets that support SVE2, it lowers to a MATCH instruction
for supported fixed and scalable types. Otherwise, the intrinsic has
generic lowering in SelectionDAGBuilder.
---
llvm/docs/LangRef.rst | 39 +++
.../llvm/Analysis/TargetTransformInfo.h | 10 +
.../llvm/Analysis/TargetTransformInfoImpl.h | 4 +
llvm/include/llvm/IR/Intrinsics.td | 8 +
llvm/lib/Analysis/TargetTransformInfo.cpp | 5 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 36 +++
llvm/lib/IR/Verifier.cpp | 21 ++
.../Target/AArch64/AArch64ISelLowering.cpp | 53 ++++
.../AArch64/AArch64TargetTransformInfo.cpp | 24 ++
.../AArch64/AArch64TargetTransformInfo.h | 2 +
.../AArch64/intrinsic-vector-match-sve2.ll | 253 ++++++++++++++++++
11 files changed, 455 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 566d0d4e4e81a3..aedb101c4af68c 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20043,6 +20043,45 @@ are undefined.
}
+'``llvm.experimental.vector.match.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. Support for specific vector types is target
+dependent.
+
+::
+
+ declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<m> x <ty>> %op2, <<n> x i1> %mask)
+ declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <<m> x <ty>> %op2, <vscale x <n> x i1> %mask)
+
+Overview:
+"""""""""
+
+Find active elements of the first argument matching any elements of the second.
+
+Arguments:
+""""""""""
+
+The first argument is the search vector, the second argument the vector of
+elements we are searching for (i.e. for which we consider a match successful),
+and the third argument is a mask that controls which elements of the first
+argument are active.
+
+Semantics:
+""""""""""
+
+The '``llvm.experimental.vector.match``' intrinsic compares each active element
+in the first argument against the elements of the second argument, placing
+``1`` in the corresponding element of the output vector if any comparison is
+successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
+in the output.
+
+The second argument needs to be a fixed-length vector with the same element
+type as the first argument.
+
Matrix Intrinsics
-----------------
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..f47874bf0407d5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1771,6 +1771,11 @@ class TargetTransformInfo {
/// This should also apply to lowering for vector funnel shifts (rotates).
bool isVectorShiftByScalarCheap(Type *Ty) const;
+ /// \returns True if the target has hardware support for vector match
+ /// operations between vectors of type `VT` and search vectors of `SearchSize`
+ /// elements, and false otherwise.
+ bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
+
struct VPLegalization {
enum VPTransform {
// keep the predicating parameter
@@ -2221,6 +2226,7 @@ class TargetTransformInfo::Concept {
SmallVectorImpl<Use *> &OpsToSink) const = 0;
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
+ virtual bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const = 0;
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3014,6 +3020,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isVectorShiftByScalarCheap(Ty);
}
+ bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const override {
+ return Impl.hasVectorMatch(VT, SearchSize);
+ }
+
VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..886acb5120330f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -995,6 +995,10 @@ class TargetTransformInfoImplBase {
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
+ bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
+ return false;
+ }
+
TargetTransformInfo::VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 94e53f372127da..ccdaf79f63f1ff 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1918,6 +1918,14 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;
+// Experimental match
+def int_experimental_vector_match : DefaultAttrsIntrinsic<
+ [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
+ [ llvm_anyvector_ty,
+ llvm_anyvector_ty,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], // Mask
+ [ IntrNoMem, IntrNoSync, IntrWillReturn ]>;
+
// Operators
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
// Integer arithmetic
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..ca7b258dc08d79 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1383,6 +1383,11 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
return TTIImpl->isVectorShiftByScalarCheap(Ty);
}
+bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
+ unsigned SearchSize) const {
+ return TTIImpl->hasVectorMatch(VT, SearchSize);
+}
+
TargetTransformInfo::Concept::~Concept() = default;
TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 8450553743074c..ca67f623f46258 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8156,6 +8156,42 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
return;
}
+ case Intrinsic::experimental_vector_match: {
+ SDValue Op1 = getValue(I.getOperand(0));
+ SDValue Op2 = getValue(I.getOperand(1));
+ SDValue Mask = getValue(I.getOperand(2));
+ EVT Op1VT = Op1.getValueType();
+ EVT Op2VT = Op2.getValueType();
+ EVT ResVT = Mask.getValueType();
+ unsigned SearchSize = Op2VT.getVectorNumElements();
+
+ LLVMContext &Ctx = *DAG.getContext();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+
+ // If the target has native support for this vector match operation, lower
+ // the intrinsic directly; otherwise, lower it below.
+ if (TTI.hasVectorMatch(cast<VectorType>(Op1VT.getTypeForEVT(Ctx)),
+ SearchSize)) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
+ }
+
+ SDValue Ret = DAG.getNode(ISD::SPLAT_VECTOR, sdl, ResVT,
+ DAG.getConstant(0, sdl, MVT::i1));
+
+ for (unsigned i = 0; i < SearchSize; ++i) {
+ SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
+ Op2VT.getVectorElementType(), Op2,
+ DAG.getVectorIdxConstant(i, sdl));
+ SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, sdl, Op1VT, Op2Elem);
+ SDValue Cmp = DAG.getSetCC(sdl, ResVT, Op1, Splat, ISD::SETEQ);
+ Ret = DAG.getNode(ISD::OR, sdl, ResVT, Ret, Cmp);
+ }
+
+ setValue(&I, DAG.getNode(ISD::AND, sdl, ResVT, Ret, Mask));
+ return;
+ }
case Intrinsic::vector_reverse:
visitVectorReverse(I);
return;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index f34fe7594c8602..8a031956d70dd9 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6154,6 +6154,27 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
&Call);
break;
}
+ case Intrinsic::experimental_vector_match: {
+ Value *Op1 = Call.getArgOperand(0);
+ Value *Op2 = Call.getArgOperand(1);
+ Value *Mask = Call.getArgOperand(2);
+
+ VectorType *Op1Ty = dyn_cast<VectorType>(Op1->getType());
+ VectorType *Op2Ty = dyn_cast<VectorType>(Op2->getType());
+ VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());
+
+ Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
+ Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
+ &Call);
+ Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
+ "First two operands must have the same element type.", &Call);
+ Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
+ "First operand and mask must have the same number of elements.",
+ &Call);
+ Check(MaskTy->getElementType()->isIntegerTy(1),
+ "Mask must be a vector of i1's.", &Call);
+ break;
+ }
case Intrinsic::vector_insert: {
Value *Vec = Call.getArgOperand(0);
Value *SubVec = Call.getArgOperand(1);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4aa123b42d1966..979a0cd904f4ad 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6364,6 +6364,58 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
}
+ case Intrinsic::experimental_vector_match: {
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+ auto Op1 = Op.getOperand(1);
+ auto Op2 = Op.getOperand(2);
+ auto Mask = Op.getOperand(3);
+
+ EVT Op1VT = Op1.getValueType();
+ EVT Op2VT = Op2.getValueType();
+ EVT ResVT = Op.getValueType();
+
+ assert((Op1VT.getVectorElementType() == MVT::i8 ||
+ Op1VT.getVectorElementType() == MVT::i16) &&
+ "Expected 8-bit or 16-bit characters.");
+ assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+ assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
+ "Operand type mismatch.");
+ assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
+ "Invalid operands.");
+
+ // Wrap the search vector in a scalable vector.
+ EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+
+ // If the result is scalable, we need to broadbast the search vector across
+ // the SVE register and then carry out the MATCH.
+ if (ResVT.isScalableVector()) {
+ Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+ DAG.getTargetConstant(0, dl, MVT::i64));
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
+ Op2);
+ }
+
+ // If the result is fixed, we can still use MATCH but we need to wrap the
+ // first operand and the mask in scalable vectors before doing so.
+ EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
+
+ // Wrap the operands.
+ Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+ Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+
+ // Carry out the match.
+ SDValue Match =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+
+ // Extract and return the result.
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
+ DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+ DAG.getVectorIdxConstant(0, dl));
+ }
}
}
@@ -27046,6 +27098,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
return;
}
+ case Intrinsic::experimental_vector_match:
case Intrinsic::get_active_lane_mask: {
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
return;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ff3c69f7e10c66..9d33a368a9e86d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4072,6 +4072,30 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
}
}
+bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
+ // Check that (i) the target has SVE2 and SVE is available, (ii) `VT' is a
+ // legal type for MATCH, and (iii) the search vector can be broadcast
+ // efficently to a legal type.
+ //
+ // Currently, we require the length of the search vector to match the minimum
+ // number of elements of `VT'. In practice this means we only support the
+ // cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
+ // first element of the tuples corresponds to the type of the first argument
+ // and the second the length of the search vector.
+ //
+ // In the future we can support more cases. For example, (nxv16i8, 4) could
+ // be efficiently supported by using a DUP.S to broadcast the search
+ // elements, and more exotic cases like (nxv16i8, 5) could be supported by a
+ // sequence of SEL(DUP).
+ if (ST->hasSVE2() && ST->isSVEAvailable() &&
+ VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
+ (VT->getElementCount().getKnownMinValue() == 8 ||
+ VT->getElementCount().getKnownMinValue() == 16) &&
+ VT->getElementCount().getKnownMinValue() == SearchSize)
+ return true;
+ return false;
+}
+
InstructionCost
AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1d09d67f6ec9e3..580bf5e79c3da1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -392,6 +392,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return ST->hasSVE();
}
+ bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
+
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
new file mode 100644
index 00000000000000..d84a54f327a9bc
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -0,0 +1,253 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64 < %s -o - | FileCheck %s
+
+define <vscale x 16 x i1> @match_nxv16i8_v1i8(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v1i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: umov w8, v1.b[0]
+; CHECK-NEXT: mov z1.b, w8
+; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v2i8(<vscale x 16 x i8> %op1, <2 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v2i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: mov w8, v1.s[1]
+; CHECK-NEXT: fmov w9, s1
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: mov z2.b, w9
+; CHECK-NEXT: mov z1.b, w8
+; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: sel p1.b, p1, p1.b, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <2 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v4i8(<vscale x 16 x i8> %op1, <4 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v4i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: umov w8, v1.h[1]
+; CHECK-NEXT: umov w9, v1.h[0]
+; CHECK-NEXT: umov w10, v1.h[2]
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: mov z3.b, w9
+; CHECK-NEXT: umov w8, v1.h[3]
+; CHECK-NEXT: mov z1.b, w10
+; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: mov p2.b, p3/m, p3.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: mov p1.b, p2/m, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <4 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v8i8(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: umov w8, v1.b[1]
+; CHECK-NEXT: umov w9, v1.b[0]
+; CHECK-NEXT: umov w10, v1.b[2]
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: mov z3.b, w9
+; CHECK-NEXT: umov w8, v1.b[3]
+; CHECK-NEXT: mov z4.b, w10
+; CHECK-NEXT: umov w9, v1.b[4]
+; CHECK-NEXT: umov w10, v1.b[7]
+; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: umov w8, v1.b[5]
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z3.b, w9
+; CHECK-NEXT: umov w9, v1.b[6]
+; CHECK-NEXT: mov p2.b, p3/m, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: mov z1.b, w8
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z2.b, w9
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w10
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: mov p1.b, p2/m, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v16i8(<vscale x 16 x i8> %op1, <16 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mov z1.q, q1
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <16 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v1i8(<16 x i8> %op1, <1 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v1i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: dup v1.16b, v1.b[0]
+; CHECK-NEXT: cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <1 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v2i8(<16 x i8> %op1, <2 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v2i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: dup v3.16b, v1.b[4]
+; CHECK-NEXT: dup v1.16b, v1.b[0]
+; CHECK-NEXT: cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT: cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: orr v0.16b, v0.16b, v3.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <2 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v4i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: dup v3.16b, v1.b[2]
+; CHECK-NEXT: dup v4.16b, v1.b[0]
+; CHECK-NEXT: dup v5.16b, v1.b[4]
+; CHECK-NEXT: dup v1.16b, v1.b[6]
+; CHECK-NEXT: cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT: cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT: cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT: cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: orr v1.16b, v4.16b, v3.16b
+; CHECK-NEXT: orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: dup v3.16b, v1.b[1]
+; CHECK-NEXT: dup v4.16b, v1.b[0]
+; CHECK-NEXT: dup v5.16b, v1.b[2]
+; CHECK-NEXT: dup v6.16b, v1.b[3]
+; CHECK-NEXT: dup v7.16b, v1.b[4]
+; CHECK-NEXT: dup v16.16b, v1.b[5]
+; CHECK-NEXT: dup v17.16b, v1.b[6]
+; CHECK-NEXT: dup v1.16b, v1.b[7]
+; CHECK-NEXT: cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT: cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT: cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT: cmeq v6.16b, v0.16b, v6.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v7.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT: orr v3.16b, v4.16b, v3.16b
+; CHECK-NEXT: orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT: orr v5.16b, v7.16b, v16.16b
+; CHECK-NEXT: cmeq v6.16b, v0.16b, v17.16b
+; CHECK-NEXT: cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: orr v3.16b, v3.16b, v4.16b
+; CHECK-NEXT: orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT: orr v3.16b, v3.16b, v4.16b
+; CHECK-NEXT: orr v0.16b, v3.16b, v0.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.b, vl16
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+define <vscale x 8 x i1> @match_nxv8i16_v8i16(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv8i16_v8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mov z1.q, q1
+; CHECK-NEXT: match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 8 x i1> @llvm.experimental.vector.match(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask)
+ ret <vscale x 8 x i1> %r
+}
+
+define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: cmpne p0.h, p0/z, z2.h, #0
+; CHECK-NEXT: match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: xtn v0.8b, v0.8h
+; CHECK-NEXT: ret
+ %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask)
+ ret <8 x i1> %r
+}
+
+attributes #0 = { "target-features"="+sve2" }
>From 3f9398dd3603962f280e8f0fa760079e446ec1db Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Wed, 16 Oct 2024 10:11:00 -0700
Subject: [PATCH 2/7] Add support to lower partial search vectors
And address other review comments.
---
llvm/lib/IR/Verifier.cpp | 2 +
.../Target/AArch64/AArch64ISelLowering.cpp | 88 +++-
.../AArch64/AArch64TargetTransformInfo.cpp | 15 +-
.../AArch64/intrinsic-vector-match-sve2.ll | 419 +++++++++++++++---
4 files changed, 424 insertions(+), 100 deletions(-)
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8a031956d70dd9..09ec2ff4628d36 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6173,6 +6173,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
&Call);
Check(MaskTy->getElementType()->isIntegerTy(1),
"Mask must be a vector of i1's.", &Call);
+ Check(Call.getType() == MaskTy, "Return type must match the mask type.",
+ &Call);
break;
}
case Intrinsic::vector_insert: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 979a0cd904f4ad..d9a21d7f60df2c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6379,42 +6379,86 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
assert((Op1VT.getVectorElementType() == MVT::i8 ||
Op1VT.getVectorElementType() == MVT::i16) &&
"Expected 8-bit or 16-bit characters.");
- assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
"Operand type mismatch.");
- assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
- "Invalid operands.");
-
- // Wrap the search vector in a scalable vector.
- EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
- Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
-
- // If the result is scalable, we need to broadbast the search vector across
- // the SVE register and then carry out the MATCH.
- if (ResVT.isScalableVector()) {
- Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
- DAG.getTargetConstant(0, dl, MVT::i64));
+ assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+
+ // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
+ // In the future we could support other types (e.g. v8i8).
+ assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
+ "Unsupported first operand type.");
+
+ // Scalable vector type used to wrap operands.
+ // A single container is enough for both operands because ultimately the
+ // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+ EVT OpContainerVT = Op1VT.isScalableVector()
+ ? Op1VT
+ : getContainerForFixedLengthVector(DAG, Op1VT);
+
+ // Wrap Op2 in a scalable register, and splat it if necessary.
+ if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
+ // If Op1 and Op2 have the same number of elements we can trivially
+ // wrap Op2 in an SVE register.
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+ // If the result is scalable, we need to broadcast Op2 to a full SVE
+ // register.
+ if (ResVT.isScalableVector())
+ Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+ DAG.getTargetConstant(0, dl, MVT::i64));
+ } else {
+ // If Op1 and Op2 have a different number of elements, we need to broadcast
+ // Op2. Ideally we would use an AArch64ISD::DUPLANE* node for this,
+ // similarly to the above, but unfortunately it seems we are missing some
+ // patterns for this. So, as an alternative, we splat Op2 through a splat of
+ // a scalable vector extract. This idiom, though a bit more verbose, is
+ // supported and gets us the MOV instruction we want.
+
+ // Some types we need. We'll use an integer type with `Op2BitWidth' bits
+ // to wrap Op2 and simulate the DUPLANE.
+ unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+ MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+ MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
+ EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
+ // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
+ // SVE register before doing the extract and splat.
+ // It is unlikely we'll be widening from types other than v8i8 or v4i16,
+ // so in practice this loop will run for a single iteration.
+ while (Op2VT.getFixedSizeInBits() != 128) {
+ Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
+ Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
+ DAG.getUNDEF(Op2.getValueType()));
+ }
+ // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+ Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+ DAG.getBitcast(Op2ScalableVT, Op2),
+ DAG.getConstant(0, dl, MVT::i64));
+ Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+ Op2 = DAG.getBitcast(OpContainerVT, Op2);
+ }
+
+ // If the result is scalable, we just need to carry out the MATCH.
+ if (ResVT.isScalableVector())
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
Op2);
- }
// If the result is fixed, we can still use MATCH but we need to wrap the
// first operand and the mask in scalable vectors before doing so.
- EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
// Wrap the operands.
Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
Mask = convertFixedMaskToScalableVector(Mask, DAG);
- // Carry out the match.
- SDValue Match =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+ // Carry out the match and extract it.
+ SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
+ Mask.getValueType(), ID, Mask, Op1, Op2);
+ Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
+ DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+ DAG.getVectorIdxConstant(0, dl));
- // Extract and return the result.
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
- DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
- DAG.getVectorIdxConstant(0, dl));
+ // Truncate and return the result.
+ return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
}
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 9d33a368a9e86d..4a68ecce3b654c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4077,21 +4077,14 @@ bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
// legal type for MATCH, and (iii) the search vector can be broadcast
// efficently to a legal type.
//
- // Currently, we require the length of the search vector to match the minimum
- // number of elements of `VT'. In practice this means we only support the
- // cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
- // first element of the tuples corresponds to the type of the first argument
- // and the second the length of the search vector.
- //
- // In the future we can support more cases. For example, (nxv16i8, 4) could
- // be efficiently supported by using a DUP.S to broadcast the search
- // elements, and more exotic cases like (nxv16i8, 5) could be supported by a
- // sequence of SEL(DUP).
+ // Currently, we require the search vector to be 64-bit or 128-bit. In the
+ // future we can support more cases.
if (ST->hasSVE2() && ST->isSVEAvailable() &&
VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
(VT->getElementCount().getKnownMinValue() == 8 ||
VT->getElementCount().getKnownMinValue() == 16) &&
- VT->getElementCount().getKnownMinValue() == SearchSize)
+ (VT->getElementCount().getKnownMinValue() == SearchSize ||
+ VT->getElementCount().getKnownMinValue() / 2 == SearchSize))
return true;
return false;
}
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index d84a54f327a9bc..e991bbc80962b5 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -68,48 +68,9 @@ define <vscale x 16 x i1> @match_nxv16i8_v4i8(<vscale x 16 x i8> %op1, <4 x i8>
define <vscale x 16 x i1> @match_nxv16i8_v8i8(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: match_nxv16i8_v8i8:
; CHECK: // %bb.0:
-; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT: addvl sp, sp, #-1
-; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
-; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-NEXT: .cfi_offset w29, -16
-; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT: umov w8, v1.b[1]
-; CHECK-NEXT: umov w9, v1.b[0]
-; CHECK-NEXT: umov w10, v1.b[2]
-; CHECK-NEXT: ptrue p1.b
-; CHECK-NEXT: mov z2.b, w8
-; CHECK-NEXT: mov z3.b, w9
-; CHECK-NEXT: umov w8, v1.b[3]
-; CHECK-NEXT: mov z4.b, w10
-; CHECK-NEXT: umov w9, v1.b[4]
-; CHECK-NEXT: umov w10, v1.b[7]
-; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z2.b
-; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z2.b, w8
-; CHECK-NEXT: umov w8, v1.b[5]
-; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z3.b, w9
-; CHECK-NEXT: umov w9, v1.b[6]
-; CHECK-NEXT: mov p2.b, p3/m, p3.b
-; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z2.b
-; CHECK-NEXT: mov z1.b, w8
-; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z2.b, w9
-; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
-; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w10
-; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z2.b
-; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
-; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT: ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
-; CHECK-NEXT: mov p1.b, p2/m, p2.b
-; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
-; CHECK-NEXT: addvl sp, sp, #1
-; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: mov z1.d, d1
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
; CHECK-NEXT: ret
%r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask)
ret <vscale x 16 x i1> %r
@@ -177,31 +138,15 @@ define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mas
define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
; CHECK-LABEL: match_v16i8_v8i8:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT: dup v3.16b, v1.b[1]
-; CHECK-NEXT: dup v4.16b, v1.b[0]
-; CHECK-NEXT: dup v5.16b, v1.b[2]
-; CHECK-NEXT: dup v6.16b, v1.b[3]
-; CHECK-NEXT: dup v7.16b, v1.b[4]
-; CHECK-NEXT: dup v16.16b, v1.b[5]
-; CHECK-NEXT: dup v17.16b, v1.b[6]
-; CHECK-NEXT: dup v1.16b, v1.b[7]
-; CHECK-NEXT: cmeq v3.16b, v0.16b, v3.16b
-; CHECK-NEXT: cmeq v4.16b, v0.16b, v4.16b
-; CHECK-NEXT: cmeq v5.16b, v0.16b, v5.16b
-; CHECK-NEXT: cmeq v6.16b, v0.16b, v6.16b
-; CHECK-NEXT: cmeq v7.16b, v0.16b, v7.16b
-; CHECK-NEXT: cmeq v16.16b, v0.16b, v16.16b
-; CHECK-NEXT: orr v3.16b, v4.16b, v3.16b
-; CHECK-NEXT: orr v4.16b, v5.16b, v6.16b
-; CHECK-NEXT: orr v5.16b, v7.16b, v16.16b
-; CHECK-NEXT: cmeq v6.16b, v0.16b, v17.16b
-; CHECK-NEXT: cmeq v0.16b, v0.16b, v1.16b
-; CHECK-NEXT: orr v3.16b, v3.16b, v4.16b
-; CHECK-NEXT: orr v4.16b, v5.16b, v6.16b
-; CHECK-NEXT: orr v3.16b, v3.16b, v4.16b
-; CHECK-NEXT: orr v0.16b, v3.16b, v0.16b
-; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: ptrue p0.b, vl16
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: mov z1.d, d1
+; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask)
ret <16 x i1> %r
@@ -250,4 +195,344 @@ define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0
ret <8 x i1> %r
}
+; Cases where op2 has more elements than op1.
+
+define <vscale x 16 x i1> @match_nxv16i8_v32i8(<vscale x 16 x i8> %op1, <32 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v32i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: umov w8, v1.b[1]
+; CHECK-NEXT: umov w9, v1.b[0]
+; CHECK-NEXT: umov w10, v1.b[2]
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w8, v1.b[3]
+; CHECK-NEXT: mov z5.b, w10
+; CHECK-NEXT: umov w9, v1.b[4]
+; CHECK-NEXT: umov w10, v1.b[15]
+; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v1.b[5]
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z5.b
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w9, v1.b[6]
+; CHECK-NEXT: mov p2.b, p3/m, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v1.b[7]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w9, v1.b[8]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v1.b[9]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w9, v1.b[10]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v1.b[11]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w9, v1.b[12]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v1.b[13]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z4.b, w9
+; CHECK-NEXT: umov w9, v1.b[14]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z1.b, w8
+; CHECK-NEXT: umov w8, v2.b[0]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT: mov z3.b, w9
+; CHECK-NEXT: umov w9, v2.b[1]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w10
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[2]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[3]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[4]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[5]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[6]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[7]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[8]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[9]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[10]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[11]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[12]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[13]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, w8
+; CHECK-NEXT: umov w8, v2.b[14]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: umov w9, v2.b[15]
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z1.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: mov p1.b, p2/m, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <32 x i8> %op2, <vscale x 16 x i1> %mask)
+ ret <vscale x 16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v32i8(<16 x i8> %op1, <32 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v32i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: dup v4.16b, v1.b[1]
+; CHECK-NEXT: dup v5.16b, v1.b[0]
+; CHECK-NEXT: dup v6.16b, v1.b[2]
+; CHECK-NEXT: dup v7.16b, v1.b[3]
+; CHECK-NEXT: dup v16.16b, v1.b[4]
+; CHECK-NEXT: dup v17.16b, v1.b[5]
+; CHECK-NEXT: dup v18.16b, v1.b[6]
+; CHECK-NEXT: dup v19.16b, v1.b[7]
+; CHECK-NEXT: dup v20.16b, v1.b[8]
+; CHECK-NEXT: cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT: cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT: cmeq v6.16b, v0.16b, v6.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v7.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT: cmeq v17.16b, v0.16b, v17.16b
+; CHECK-NEXT: dup v21.16b, v2.b[7]
+; CHECK-NEXT: dup v22.16b, v1.b[10]
+; CHECK-NEXT: orr v4.16b, v5.16b, v4.16b
+; CHECK-NEXT: orr v5.16b, v6.16b, v7.16b
+; CHECK-NEXT: orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT: cmeq v17.16b, v0.16b, v20.16b
+; CHECK-NEXT: dup v18.16b, v1.b[9]
+; CHECK-NEXT: dup v19.16b, v1.b[11]
+; CHECK-NEXT: dup v20.16b, v1.b[12]
+; CHECK-NEXT: cmeq v22.16b, v0.16b, v22.16b
+; CHECK-NEXT: orr v4.16b, v4.16b, v5.16b
+; CHECK-NEXT: orr v5.16b, v6.16b, v7.16b
+; CHECK-NEXT: orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT: dup v18.16b, v1.b[13]
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT: cmeq v17.16b, v0.16b, v20.16b
+; CHECK-NEXT: dup v19.16b, v2.b[0]
+; CHECK-NEXT: dup v20.16b, v2.b[1]
+; CHECK-NEXT: orr v4.16b, v4.16b, v5.16b
+; CHECK-NEXT: dup v5.16b, v2.b[6]
+; CHECK-NEXT: orr v6.16b, v6.16b, v7.16b
+; CHECK-NEXT: orr v7.16b, v16.16b, v17.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v18.16b
+; CHECK-NEXT: cmeq v17.16b, v0.16b, v19.16b
+; CHECK-NEXT: cmeq v18.16b, v0.16b, v20.16b
+; CHECK-NEXT: dup v19.16b, v2.b[2]
+; CHECK-NEXT: cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT: cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT: dup v21.16b, v2.b[8]
+; CHECK-NEXT: orr v6.16b, v6.16b, v22.16b
+; CHECK-NEXT: orr v7.16b, v7.16b, v16.16b
+; CHECK-NEXT: dup v16.16b, v1.b[14]
+; CHECK-NEXT: dup v1.16b, v1.b[15]
+; CHECK-NEXT: orr v17.16b, v17.16b, v18.16b
+; CHECK-NEXT: cmeq v18.16b, v0.16b, v19.16b
+; CHECK-NEXT: dup v19.16b, v2.b[3]
+; CHECK-NEXT: orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT: cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT: dup v21.16b, v2.b[9]
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT: cmeq v1.16b, v0.16b, v1.16b
+; CHECK-NEXT: orr v4.16b, v4.16b, v6.16b
+; CHECK-NEXT: orr v17.16b, v17.16b, v18.16b
+; CHECK-NEXT: cmeq v18.16b, v0.16b, v19.16b
+; CHECK-NEXT: dup v19.16b, v2.b[4]
+; CHECK-NEXT: orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT: cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT: dup v21.16b, v2.b[10]
+; CHECK-NEXT: orr v7.16b, v7.16b, v16.16b
+; CHECK-NEXT: orr v16.16b, v17.16b, v18.16b
+; CHECK-NEXT: cmeq v17.16b, v0.16b, v19.16b
+; CHECK-NEXT: dup v18.16b, v2.b[5]
+; CHECK-NEXT: orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT: cmeq v19.16b, v0.16b, v21.16b
+; CHECK-NEXT: dup v20.16b, v2.b[11]
+; CHECK-NEXT: orr v1.16b, v7.16b, v1.16b
+; CHECK-NEXT: orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT: dup v17.16b, v2.b[12]
+; CHECK-NEXT: orr v5.16b, v5.16b, v19.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v20.16b
+; CHECK-NEXT: dup v18.16b, v2.b[13]
+; CHECK-NEXT: dup v19.16b, v2.b[14]
+; CHECK-NEXT: orr v1.16b, v4.16b, v1.16b
+; CHECK-NEXT: dup v2.16b, v2.b[15]
+; CHECK-NEXT: orr v4.16b, v6.16b, v7.16b
+; CHECK-NEXT: cmeq v6.16b, v0.16b, v17.16b
+; CHECK-NEXT: orr v5.16b, v5.16b, v16.16b
+; CHECK-NEXT: cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT: cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT: cmeq v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: orr v1.16b, v1.16b, v4.16b
+; CHECK-NEXT: orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT: orr v5.16b, v7.16b, v16.16b
+; CHECK-NEXT: orr v1.16b, v1.16b, v4.16b
+; CHECK-NEXT: orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v3.16b
+; CHECK-NEXT: ret
+ %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <32 x i8> %op2, <16 x i1> %mask)
+ ret <16 x i1> %r
+}
+
+; Data types not supported by MATCH.
+; Note: The cases for SVE could be made tighter.
+
+define <vscale x 4 x i1> @match_nxv4xi32_v4i32(<vscale x 4 x i32> %op1, <4 x i32> %op2, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv4xi32_v4i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: mov w8, v1.s[1]
+; CHECK-NEXT: fmov w10, s1
+; CHECK-NEXT: mov w9, v1.s[2]
+; CHECK-NEXT: mov w11, v1.s[3]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mov z2.s, w10
+; CHECK-NEXT: mov z1.s, w8
+; CHECK-NEXT: mov z3.s, w9
+; CHECK-NEXT: cmpeq p3.s, p1/z, z0.s, z2.s
+; CHECK-NEXT: cmpeq p2.s, p1/z, z0.s, z1.s
+; CHECK-NEXT: mov z1.s, w11
+; CHECK-NEXT: cmpeq p4.s, p1/z, z0.s, z3.s
+; CHECK-NEXT: cmpeq p1.s, p1/z, z0.s, z1.s
+; CHECK-NEXT: mov p2.b, p3/m, p3.b
+; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT: ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: mov p1.b, p2/m, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 4 x i1> @llvm.experimental.vector.match(<vscale x 4 x i32> %op1, <4 x i32> %op2, <vscale x 4 x i1> %mask)
+ ret <vscale x 4 x i1> %r
+}
+
+define <vscale x 2 x i1> @match_nxv2xi64_v2i64(<vscale x 2 x i64> %op1, <2 x i64> %op2, <vscale x 2 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv2xi64_v2i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov x8, v1.d[1]
+; CHECK-NEXT: fmov x9, d1
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mov z2.d, x9
+; CHECK-NEXT: mov z1.d, x8
+; CHECK-NEXT: cmpeq p2.d, p1/z, z0.d, z1.d
+; CHECK-NEXT: cmpeq p1.d, p1/z, z0.d, z2.d
+; CHECK-NEXT: sel p1.b, p1, p1.b, p2.b
+; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT: ret
+ %r = tail call <vscale x 2 x i1> @llvm.experimental.vector.match(<vscale x 2 x i64> %op1, <2 x i64> %op2, <vscale x 2 x i1> %mask)
+ ret <vscale x 2 x i1> %r
+}
+
+define <4 x i1> @match_v4xi32_v4i32(<4 x i32> %op1, <4 x i32> %op2, <4 x i1> %mask) #0 {
+; CHECK-LABEL: match_v4xi32_v4i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: dup v3.4s, v1.s[1]
+; CHECK-NEXT: dup v4.4s, v1.s[0]
+; CHECK-NEXT: dup v5.4s, v1.s[2]
+; CHECK-NEXT: dup v1.4s, v1.s[3]
+; CHECK-NEXT: cmeq v3.4s, v0.4s, v3.4s
+; CHECK-NEXT: cmeq v4.4s, v0.4s, v4.4s
+; CHECK-NEXT: cmeq v5.4s, v0.4s, v5.4s
+; CHECK-NEXT: cmeq v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: orr v1.16b, v4.16b, v3.16b
+; CHECK-NEXT: orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: and v0.8b, v0.8b, v2.8b
+; CHECK-NEXT: ret
+ %r = tail call <4 x i1> @llvm.experimental.vector.match(<4 x i32> %op1, <4 x i32> %op2, <4 x i1> %mask)
+ ret <4 x i1> %r
+}
+
+define <2 x i1> @match_v2xi64_v2i64(<2 x i64> %op1, <2 x i64> %op2, <2 x i1> %mask) #0 {
+; CHECK-LABEL: match_v2xi64_v2i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: dup v3.2d, v1.d[1]
+; CHECK-NEXT: dup v1.2d, v1.d[0]
+; CHECK-NEXT: cmeq v3.2d, v0.2d, v3.2d
+; CHECK-NEXT: cmeq v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: orr v0.16b, v0.16b, v3.16b
+; CHECK-NEXT: xtn v0.2s, v0.2d
+; CHECK-NEXT: and v0.8b, v0.8b, v2.8b
+; CHECK-NEXT: ret
+ %r = tail call <2 x i1> @llvm.experimental.vector.match(<2 x i64> %op1, <2 x i64> %op2, <2 x i1> %mask)
+ ret <2 x i1> %r
+}
+
attributes #0 = { "target-features"="+sve2" }
>From 56dedaff6107a8625e6ab0d50ae1d8c84b1cb350 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Thu, 24 Oct 2024 09:36:10 -0700
Subject: [PATCH 3/7] Address review comments
---
.../Target/AArch64/AArch64ISelLowering.cpp | 41 +++++++------------
.../AArch64/AArch64TargetTransformInfo.cpp | 20 ++++-----
2 files changed, 25 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d9a21d7f60df2c..c51c079b9b61a5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6397,8 +6397,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
// Wrap Op2 in a scalable register, and splat it if necessary.
if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
- // If Op1 and Op2 have the same number of elements we can trivially
- // wrapping Op2 in an SVE register.
+ // If Op1 and Op2 have the same number of elements we can trivially wrap
+ // Op2 in an SVE register.
Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
// If the result is scalable, we need to broadcast Op2 to a full SVE
// register.
@@ -6408,32 +6408,21 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
} else {
// If Op1 and Op2 have different number of elements, we need to broadcast
// Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
- // similarly to the above, but unfortunately it seems we are missing some
+ // similarly to the above, but unfortunately we seem to be missing some
// patterns for this. So, in alternative, we splat Op2 through a splat of
// a scalable vector extract. This idiom, though a bit more verbose, is
// supported and get us the MOV instruction we want.
-
- // Some types we need. We'll use an integer type with `Op2BitWidth' bits
- // to wrap Op2 and simulate the DUPLANE.
unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
- MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
- EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
- // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
- // SVE register before doing the extract and splat.
- // It is unlikely we'll be widening from types other than v8i8 or v4i16,
- // so in practice this loop will run for a single iteration.
- while (Op2VT.getFixedSizeInBits() != 128) {
- Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
- Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
- DAG.getUNDEF(Op2.getValueType()));
- }
- // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
- Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+ MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
+ /*IsScalable=*/true);
+ SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
+ DAG.getUNDEF(OpContainerVT), Op2,
+ DAG.getConstant(0, dl, MVT::i64));
Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
- DAG.getBitcast(Op2ScalableVT, Op2),
+ DAG.getBitcast(Op2PromotedVT, Op2Widened),
DAG.getConstant(0, dl, MVT::i64));
- Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+ Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
Op2 = DAG.getBitcast(OpContainerVT, Op2);
}
@@ -6450,14 +6439,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
Mask = convertFixedMaskToScalableVector(Mask, DAG);
- // Carry out the match and extract it.
+ // Carry out the match.
SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
Mask.getValueType(), ID, Mask, Op1, Op2);
- Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
- DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
- DAG.getVectorIdxConstant(0, dl));
- // Truncate and return the result.
+ // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
+ // (v16i8/v8i8).
+ Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
+ Match = convertFromScalableVector(DAG, Op1VT, Match);
return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 4a68ecce3b654c..c572337fbb79aa 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4073,18 +4073,18 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
}
bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
- // Check that (i) the target has SVE2 and SVE is available, (ii) `VT' is a
- // legal type for MATCH, and (iii) the search vector can be broadcast
- // efficently to a legal type.
- //
+ // Check that the target has SVE2 and SVE is available.
+ if (!ST->hasSVE2() || !ST->isSVEAvailable())
+ return false;
+
+ // Check that `VT' is a legal type for MATCH, and that the search vector can
+ // be broadcast efficiently if necessary.
// Currently, we require the search vector to be 64-bit or 128-bit. In the
- // future we can support more cases.
- if (ST->hasSVE2() && ST->isSVEAvailable() &&
+ // future we can generalise this to other lengths.
+ unsigned MinEC = VT->getElementCount().getKnownMinValue();
+ if ((MinEC == 8 || MinEC == 16) &&
VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
- (VT->getElementCount().getKnownMinValue() == 8 ||
- VT->getElementCount().getKnownMinValue() == 16) &&
- (VT->getElementCount().getKnownMinValue() == SearchSize ||
- VT->getElementCount().getKnownMinValue() / 2 == SearchSize))
+ (MinEC == SearchSize || MinEC / 2 == SearchSize))
return true;
return false;
}
>From d09a0cc07ed13b58d022e63fed2cf7c61da8a42e Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Thu, 31 Oct 2024 06:59:56 -0700
Subject: [PATCH 4/7] Move lowering to LowerVectorMatch and replace TTI hook
with TLI
(And address other minor feedback.)
---
llvm/docs/LangRef.rst | 14 +-
.../llvm/Analysis/TargetTransformInfo.h | 10 -
.../llvm/Analysis/TargetTransformInfoImpl.h | 4 -
llvm/include/llvm/CodeGen/TargetLowering.h | 7 +
llvm/lib/Analysis/TargetTransformInfo.cpp | 5 -
.../SelectionDAG/SelectionDAGBuilder.cpp | 7 +-
llvm/lib/IR/Verifier.cpp | 2 +
.../Target/AArch64/AArch64ISelLowering.cpp | 175 +++++++++---------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +
.../AArch64/AArch64TargetTransformInfo.cpp | 17 --
.../AArch64/AArch64TargetTransformInfo.h | 2 -
.../AArch64/intrinsic-vector-match-sve2.ll | 25 ++-
12 files changed, 135 insertions(+), 135 deletions(-)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index aedb101c4af68c..53b4f746dae4f0 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20049,8 +20049,7 @@ are undefined.
Syntax:
"""""""
-This is an overloaded intrinsic. Support for specific vector types is target
-dependent.
+This is an overloaded intrinsic.
::
@@ -20068,16 +20067,19 @@ Arguments:
The first argument is the search vector, the second argument the vector of
elements we are searching for (i.e. for which we consider a match successful),
and the third argument is a mask that controls which elements of the first
-argument are active.
+argument are active. The first two arguments must be vectors of matching
+integer element types. The first and third arguments and the result type must
+have matching element counts (fixed or scalable). The second argument must be a
+fixed vector, but its length may be different from the remaining arguments.
Semantics:
""""""""""
The '``llvm.experimental.vector.match``' intrinsic compares each active element
in the first argument against the elements of the second argument, placing
-``1`` in the corresponding element of the output vector if any comparison is
-successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
-in the output.
+``1`` in the corresponding element of the output vector if any equality
+comparison is successful, and ``0`` otherwise. Inactive elements in the mask
+are set to ``0`` in the output.
The second argument needs to be a fixed-length vector with the same element
type as the first argument.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f47874bf0407d5..0459941fe05cdc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1771,11 +1771,6 @@ class TargetTransformInfo {
/// This should also apply to lowering for vector funnel shifts (rotates).
bool isVectorShiftByScalarCheap(Type *Ty) const;
- /// \returns True if the target has hardware support for vector match
- /// operations between vectors of type `VT` and search vectors of `SearchSize`
- /// elements, and false otherwise.
- bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
-
struct VPLegalization {
enum VPTransform {
// keep the predicating parameter
@@ -2226,7 +2221,6 @@ class TargetTransformInfo::Concept {
SmallVectorImpl<Use *> &OpsToSink) const = 0;
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
- virtual bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const = 0;
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3020,10 +3014,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isVectorShiftByScalarCheap(Ty);
}
- bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const override {
- return Impl.hasVectorMatch(VT, SearchSize);
- }
-
VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 886acb5120330f..dbdfb4d8cdfa32 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -995,10 +995,6 @@ class TargetTransformInfoImplBase {
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
- bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
- return false;
- }
-
TargetTransformInfo::VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 8e0cdc6f1a5e77..231cfd30b426c8 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -483,6 +483,13 @@ class TargetLoweringBase {
bool ZeroIsPoison,
const ConstantRange *VScaleRange) const;
+ /// Return true if the @llvm.experimental.vector.match intrinsic should be
+ /// expanded for vector type `VT' and search size `SearchSize' using generic
+ /// code in SelectionDAGBuilder.
+ virtual bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const {
+ return true;
+ }
+
// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
// vecreduce(op(x, y)) for the reduction opcode RedOpc.
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index ca7b258dc08d79..a47462b61e03b2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1383,11 +1383,6 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
return TTIImpl->isVectorShiftByScalarCheap(Ty);
}
-bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
- unsigned SearchSize) const {
- return TTIImpl->hasVectorMatch(VT, SearchSize);
-}
-
TargetTransformInfo::Concept::~Concept() = default;
TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca67f623f46258..b0b2a3bdd11988 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8165,14 +8165,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
EVT ResVT = Mask.getValueType();
unsigned SearchSize = Op2VT.getVectorNumElements();
- LLVMContext &Ctx = *DAG.getContext();
- const auto &TTI =
- TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
-
// If the target has native support for this vector match operation, lower
// the intrinsic directly; otherwise, lower it below.
- if (TTI.hasVectorMatch(cast<VectorType>(Op1VT.getTypeForEVT(Ctx)),
- SearchSize)) {
+ if (!TLI.shouldExpandVectorMatch(Op1VT, SearchSize)) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 09ec2ff4628d36..4f222269b5b878 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6166,6 +6166,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
&Call);
+ Check(Op1Ty->getElementType()->isIntegerTy(),
+ "First operand must be a vector of integers.", &Call);
Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
"First two operands must have the same element type.", &Call);
Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c51c079b9b61a5..d891b02da3d46c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2052,6 +2052,19 @@ bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
VT != MVT::v4i1 && VT != MVT::v2i1;
}
+bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
+ unsigned SearchSize) const {
+ // MATCH is SVE2 and only available in non-streaming mode.
+ if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
+ return true;
+ // Furthermore, we can only use it for 8-bit or 16-bit characters.
+ if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
+ return SearchSize != 8;
+ if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
+ return SearchSize != 8 && SearchSize != 16;
+ return true;
+}
+
void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
@@ -5761,6 +5774,84 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
}
+SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
+ SDLoc dl(Op);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+ auto Op1 = Op.getOperand(1);
+ auto Op2 = Op.getOperand(2);
+ auto Mask = Op.getOperand(3);
+
+ EVT Op1VT = Op1.getValueType();
+ EVT Op2VT = Op2.getValueType();
+ EVT ResVT = Op.getValueType();
+
+ assert((Op1VT.getVectorElementType() == MVT::i8 ||
+ Op1VT.getVectorElementType() == MVT::i16) &&
+ "Expected 8-bit or 16-bit characters.");
+
+ // Scalable vector type used to wrap operands.
+ // A single container is enough for both operands because ultimately the
+ // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+ EVT OpContainerVT = Op1VT.isScalableVector()
+ ? Op1VT
+ : getContainerForFixedLengthVector(DAG, Op1VT);
+
+ // Wrap Op2 in a scalable register, and splat it if necessary.
+ if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
+ // If Op1 and Op2 have the same number of elements we can trivially wrap
+ // Op2 in an SVE register.
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+ // If the result is scalable, we need to broadcast Op2 to a full SVE
+ // register.
+ if (ResVT.isScalableVector())
+ Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+ DAG.getTargetConstant(0, dl, MVT::i64));
+ } else {
+ // If Op1 and Op2 have different number of elements, we need to broadcast
+ // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
+ // similarly to the above, but unfortunately we seem to be missing some
+ // patterns for this. So, in alternative, we splat Op2 through a splat of
+ // a scalable vector extract. This idiom, though a bit more verbose, is
+ // supported and get us the MOV instruction we want.
+ unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+ MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+ MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
+ /*IsScalable=*/true);
+ SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
+ DAG.getUNDEF(OpContainerVT), Op2,
+ DAG.getConstant(0, dl, MVT::i64));
+ Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+ DAG.getBitcast(Op2PromotedVT, Op2Widened),
+ DAG.getConstant(0, dl, MVT::i64));
+ Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
+ Op2 = DAG.getBitcast(OpContainerVT, Op2);
+ }
+
+ // If the result is scalable, we just need to carry out the MATCH.
+ if (ResVT.isScalableVector())
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1, Op2);
+
+ // If the result is fixed, we can still use MATCH but we need to wrap the
+ // first operand and the mask in scalable vectors before doing so.
+
+ // Wrap the operands.
+ Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+ Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, Op1VT, Mask);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+
+ // Carry out the match.
+ SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Mask.getValueType(),
+ ID, Mask, Op1, Op2);
+
+ // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
+ // (v16i8/v8i8).
+ Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
+ Match = convertFromScalableVector(DAG, Op1VT, Match);
+ return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
+}
+
SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -6365,89 +6456,7 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
}
case Intrinsic::experimental_vector_match: {
- SDValue ID =
- DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
-
- auto Op1 = Op.getOperand(1);
- auto Op2 = Op.getOperand(2);
- auto Mask = Op.getOperand(3);
-
- EVT Op1VT = Op1.getValueType();
- EVT Op2VT = Op2.getValueType();
- EVT ResVT = Op.getValueType();
-
- assert((Op1VT.getVectorElementType() == MVT::i8 ||
- Op1VT.getVectorElementType() == MVT::i16) &&
- "Expected 8-bit or 16-bit characters.");
- assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
- "Operand type mismatch.");
- assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
-
- // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
- // In the future we could support other types (e.g. v8i8).
- assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
- "Unsupported first operand type.");
-
- // Scalable vector type used to wrap operands.
- // A single container is enough for both operands because ultimately the
- // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
- EVT OpContainerVT = Op1VT.isScalableVector()
- ? Op1VT
- : getContainerForFixedLengthVector(DAG, Op1VT);
-
- // Wrap Op2 in a scalable register, and splat it if necessary.
- if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
- // If Op1 and Op2 have the same number of elements we can trivially wrap
- // Op2 in an SVE register.
- Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
- // If the result is scalable, we need to broadcast Op2 to a full SVE
- // register.
- if (ResVT.isScalableVector())
- Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
- DAG.getTargetConstant(0, dl, MVT::i64));
- } else {
- // If Op1 and Op2 have different number of elements, we need to broadcast
- // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
- // similarly to the above, but unfortunately we seem to be missing some
- // patterns for this. So, in alternative, we splat Op2 through a splat of
- // a scalable vector extract. This idiom, though a bit more verbose, is
- // supported and get us the MOV instruction we want.
- unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
- MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
- MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
- /*IsScalable=*/true);
- SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
- DAG.getUNDEF(OpContainerVT), Op2,
- DAG.getConstant(0, dl, MVT::i64));
- Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
- DAG.getBitcast(Op2PromotedVT, Op2Widened),
- DAG.getConstant(0, dl, MVT::i64));
- Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
- Op2 = DAG.getBitcast(OpContainerVT, Op2);
- }
-
- // If the result is scalable, we just need to carry out the MATCH.
- if (ResVT.isScalableVector())
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
- Op2);
-
- // If the result is fixed, we can still use MATCH but we need to wrap the
- // first operand and the mask in scalable vectors before doing so.
-
- // Wrap the operands.
- Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
- Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
- Mask = convertFixedMaskToScalableVector(Mask, DAG);
-
- // Carry out the match.
- SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
- Mask.getValueType(), ID, Mask, Op1, Op2);
-
- // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
- // (v16i8/v8i8).
- Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
- Match = convertFromScalableVector(DAG, Op1VT, Match);
- return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
+ return LowerVectorMatch(Op, DAG);
}
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 160cd18ca53b32..d7eaf2e9b724f6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -993,6 +993,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandCttzElements(EVT VT) const override;
+ bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
+
/// If a change in streaming mode is required on entry to/return from a
/// function call it emits and returns the corresponding SMSTART or SMSTOP
/// node. \p Condition should be one of the enum values from
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c572337fbb79aa..ff3c69f7e10c66 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4072,23 +4072,6 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
}
}
-bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
- // Check that the target has SVE2 and SVE is available.
- if (!ST->hasSVE2() || !ST->isSVEAvailable())
- return false;
-
- // Check that `VT' is a legal type for MATCH, and that the search vector can
- // be broadcast efficently if necessary.
- // Currently, we require the search vector to be 64-bit or 128-bit. In the
- // future we can generalise this to other lengths.
- unsigned MinEC = VT->getElementCount().getKnownMinValue();
- if ((MinEC == 8 || MinEC == 16) &&
- VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
- (MinEC == SearchSize || MinEC / 2 == SearchSize))
- return true;
- return false;
-}
-
InstructionCost
AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 580bf5e79c3da1..1d09d67f6ec9e3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -392,8 +392,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return ST->hasSVE();
}
- bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
-
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index e991bbc80962b5..1a004f4a574ea3 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -138,11 +138,12 @@ define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mas
define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
; CHECK-LABEL: match_v16i8_v8i8:
; CHECK: // %bb.0:
+; CHECK-NEXT: shl v2.16b, v2.16b, #7
; CHECK-NEXT: ptrue p0.b, vl16
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: mov z1.d, d1
+; CHECK-NEXT: cmlt v2.16b, v2.16b, #0
; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
@@ -155,10 +156,11 @@ define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mas
define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
; CHECK-LABEL: match_v16i8_v16i8:
; CHECK: // %bb.0:
+; CHECK-NEXT: shl v2.16b, v2.16b, #7
; CHECK-NEXT: ptrue p0.b, vl16
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: cmlt v2.16b, v2.16b, #0
; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
@@ -168,6 +170,23 @@ define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %m
ret <16 x i1> %r
}
+define <8 x i1> @match_v8i8_v8i8(<8 x i8> %op1, <8 x i8> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i8_v8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v2.8b, v2.8b, #7
+; CHECK-NEXT: ptrue p0.b, vl8
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: cmlt v2.8b, v2.8b, #0
+; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+ %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i8> %op1, <8 x i8> %op2, <8 x i1> %mask)
+ ret <8 x i1> %r
+}
+
define <vscale x 8 x i1> @match_nxv8i16_v8i16(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: match_nxv8i16_v8i16:
; CHECK: // %bb.0:
@@ -186,6 +205,8 @@ define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0
; CHECK-NEXT: ptrue p0.h, vl8
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: shl v2.8h, v2.8h, #15
+; CHECK-NEXT: cmlt v2.8h, v2.8h, #0
; CHECK-NEXT: cmpne p0.h, p0/z, z2.h, #0
; CHECK-NEXT: match p0.h, p0/z, z0.h, z1.h
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
>From bcf5b91890f54ae5b632e93272b9da81916b8d06 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Fri, 8 Nov 2024 08:35:40 -0800
Subject: [PATCH 5/7] Fix Op2 broadcasting and address other comments
Decide to broadcast Op2 based on the result of `Op2VT.is128BitVector()`
rather than by comparing its number of elements with Op1. This fixes a
bug with search and needle vectors of `v8i8`, and enables us to match
search vectors of `v8i8` with needle vectors of `v16i8`.
Also address other review comments.
---
llvm/docs/LangRef.rst | 3 ---
.../SelectionDAG/SelectionDAGBuilder.cpp | 5 ++--
llvm/lib/IR/Verifier.cpp | 4 ++--
.../Target/AArch64/AArch64ISelLowering.cpp | 23 ++++++++-----------
.../AArch64/intrinsic-vector-match-sve2.ll | 18 +++++++++++++++
5 files changed, 32 insertions(+), 21 deletions(-)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 53b4f746dae4f0..b4af06437dd23f 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20081,9 +20081,6 @@ in the first argument against the elements of the second argument, placing
comparison is successful, and ``0`` otherwise. Inactive elements in the mask
are set to ``0`` in the output.
-The second argument needs to be a fixed-length vector with the same element
-type as the first argument.
-
Matrix Intrinsics
-----------------
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b0b2a3bdd11988..c47a1ef664fbc1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8166,14 +8166,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
unsigned SearchSize = Op2VT.getVectorNumElements();
// If the target has native support for this vector match operation, lower
- // the intrinsic directly; otherwise, lower it below.
+ // the intrinsic untouched; otherwise, expand it below.
if (!TLI.shouldExpandVectorMatch(Op1VT, SearchSize)) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
- SDValue Ret = DAG.getNode(ISD::SPLAT_VECTOR, sdl, ResVT,
- DAG.getConstant(0, sdl, MVT::i1));
+ SDValue Ret = DAG.getConstant(0, sdl, ResVT);
for (unsigned i = 0; i < SearchSize; ++i) {
SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 4f222269b5b878..552b5b40fd82d0 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6164,8 +6164,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());
Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
- Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
- &Call);
+ Check(isa<FixedVectorType>(Op2Ty),
+ "Second operand must be a fixed length vector.", &Call);
Check(Op1Ty->getElementType()->isIntegerTy(),
"First operand must be a vector of integers.", &Call);
Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d891b02da3d46c..aa13e841a21b19 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2057,7 +2057,7 @@ bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
// MATCH is SVE2 and only available in non-streaming mode.
if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
return true;
- // Furthermore, we can only use it for 8-bit or 16-bit characters.
+ // Furthermore, we can only use it for 8-bit or 16-bit elements.
if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
return SearchSize != 8;
if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
@@ -5798,23 +5798,20 @@ SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
? Op1VT
: getContainerForFixedLengthVector(DAG, Op1VT);
- // Wrap Op2 in a scalable register, and splat it if necessary.
- if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
- // If Op1 and Op2 have the same number of elements we can trivially wrap
- // Op2 in an SVE register.
+ if (Op2VT.is128BitVector()) {
+ // If Op2 is a full 128-bit vector, wrap it trivially in a scalable vector.
Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
- // If the result is scalable, we need to broadcast Op2 to a full SVE
- // register.
+ // Further, if the result is scalable, broadcast Op2 to a full SVE register.
if (ResVT.isScalableVector())
Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
DAG.getTargetConstant(0, dl, MVT::i64));
} else {
- // If Op1 and Op2 have different number of elements, we need to broadcast
- // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
- // similarly to the above, but unfortunately we seem to be missing some
- // patterns for this. So, in alternative, we splat Op2 through a splat of
- // a scalable vector extract. This idiom, though a bit more verbose, is
- // supported and get us the MOV instruction we want.
+ // If Op2 is not a full 128-bit vector, we need to broadcast it. Ideally we
+ // would use a AArch64ISD::DUPLANE* node for this similarly to the below,
+ // but unfortunately we seem to be missing some patterns for this. So, in
+ // alternative, we splat Op2 through a splat of a scalable vector extract.
+ // This idiom, though a bit more verbose, is supported and get us the MOV
+ // instruction we want.
unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index 1a004f4a574ea3..54e323118dd2ef 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -177,6 +177,7 @@ define <8 x i1> @match_v8i8_v8i8(<8 x i8> %op1, <8 x i8> %op2, <8 x i1> %mask) #
; CHECK-NEXT: ptrue p0.b, vl8
; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: mov z1.d, d1
; CHECK-NEXT: cmlt v2.8b, v2.8b, #0
; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
@@ -218,6 +219,23 @@ define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0
; Cases where op2 has more elements than op1.
+define <8 x i1> @match_v8i8_v16i8(<8 x i8> %op1, <16 x i8> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i8_v16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v2.8b, v2.8b, #7
+; CHECK-NEXT: ptrue p0.b, vl8
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: cmlt v2.8b, v2.8b, #0
+; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+ %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i8> %op1, <16 x i8> %op2, <8 x i1> %mask)
+ ret <8 x i1> %r
+}
+
define <vscale x 16 x i1> @match_nxv16i8_v32i8(<vscale x 16 x i8> %op1, <32 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: match_nxv16i8_v32i8:
; CHECK: // %bb.0:
>From b2f1a6ad630ee9b174c312b65624cb043834d9a9 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Mon, 11 Nov 2024 07:23:07 -0800
Subject: [PATCH 6/7] Improve Op2 broadcast
---
.../Target/AArch64/AArch64ISelLowering.cpp | 14 +-
.../AArch64/intrinsic-vector-match-sve2.ll | 139 +++++++-----------
2 files changed, 55 insertions(+), 98 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 26f7ad36c43075..b30cac3975ff4d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5823,21 +5823,13 @@ SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
DAG.getTargetConstant(0, dl, MVT::i64));
} else {
- // If Op2 is not a full 128-bit vector, we need to broadcast it. Ideally we
- // would use a AArch64ISD::DUPLANE* node for this similarly to the below,
- // but unfortunately we seem to be missing some patterns for this. So, in
- // alternative, we splat Op2 through a splat of a scalable vector extract.
- // This idiom, though a bit more verbose, is supported and get us the MOV
- // instruction we want.
+ // If Op2 is not a full 128-bit vector, we always need to broadcast it.
unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
/*IsScalable=*/true);
- SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
- DAG.getUNDEF(OpContainerVT), Op2,
- DAG.getConstant(0, dl, MVT::i64));
- Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
- DAG.getBitcast(Op2PromotedVT, Op2Widened),
+ Op2 = DAG.getBitcast(MVT::getVectorVT(Op2IntVT, 1), Op2);
+ Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT, Op2,
DAG.getConstant(0, dl, MVT::i64));
Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
Op2 = DAG.getBitcast(OpContainerVT, Op2);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index 54e323118dd2ef..2cf8621ca066dd 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -4,9 +4,8 @@
define <vscale x 16 x i1> @match_nxv16i8_v1i8(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: match_nxv16i8_v1i8:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT: umov w8, v1.b[0]
-; CHECK-NEXT: mov z1.b, w8
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: mov z1.b, b1
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, z1.b
; CHECK-NEXT: ret
%r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask)
@@ -244,130 +243,100 @@ define <vscale x 16 x i1> @match_nxv16i8_v32i8(<vscale x 16 x i8> %op1, <32 x i8
; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
-; CHECK-NEXT: umov w8, v1.b[1]
-; CHECK-NEXT: umov w9, v1.b[0]
-; CHECK-NEXT: umov w10, v1.b[2]
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mov z3.b, z1.b[1]
+; CHECK-NEXT: mov z4.b, b1
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: ptrue p1.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w8, v1.b[3]
-; CHECK-NEXT: mov z5.b, w10
-; CHECK-NEXT: umov w9, v1.b[4]
-; CHECK-NEXT: umov w10, v1.b[15]
+; CHECK-NEXT: mov z5.b, z1.b[2]
; CHECK-NEXT: cmpeq p2.b, p1/z, z0.b, z3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v1.b[5]
+; CHECK-NEXT: mov z3.b, z1.b[3]
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z5.b
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w9, v1.b[6]
+; CHECK-NEXT: mov z4.b, z1.b[4]
; CHECK-NEXT: mov p2.b, p3/m, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v1.b[7]
+; CHECK-NEXT: mov z3.b, z1.b[5]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w9, v1.b[8]
+; CHECK-NEXT: mov z4.b, z1.b[6]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v1.b[9]
+; CHECK-NEXT: mov z3.b, z1.b[7]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w9, v1.b[10]
+; CHECK-NEXT: mov z4.b, z1.b[8]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v1.b[11]
+; CHECK-NEXT: mov z3.b, z1.b[9]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w9, v1.b[12]
+; CHECK-NEXT: mov z4.b, z1.b[10]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v1.b[13]
+; CHECK-NEXT: mov z3.b, z1.b[11]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z4.b, w9
-; CHECK-NEXT: umov w9, v1.b[14]
+; CHECK-NEXT: mov z4.b, z1.b[12]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z1.b, w8
-; CHECK-NEXT: umov w8, v2.b[0]
+; CHECK-NEXT: mov z3.b, z1.b[13]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT: mov z3.b, w9
-; CHECK-NEXT: umov w9, v2.b[1]
+; CHECK-NEXT: mov z4.b, z1.b[14]
+; CHECK-NEXT: mov z1.b, z1.b[15]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
-; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w10
+; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT: mov z3.b, b2
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[2]
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z4.b
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[3]
+; CHECK-NEXT: mov z1.b, z2.b[1]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[4]
+; CHECK-NEXT: mov z3.b, z2.b[2]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[5]
+; CHECK-NEXT: mov z1.b, z2.b[3]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[6]
+; CHECK-NEXT: mov z3.b, z2.b[4]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[7]
+; CHECK-NEXT: mov z1.b, z2.b[5]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[8]
+; CHECK-NEXT: mov z3.b, z2.b[6]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[9]
+; CHECK-NEXT: mov z1.b, z2.b[7]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[10]
+; CHECK-NEXT: mov z3.b, z2.b[8]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[11]
+; CHECK-NEXT: mov z1.b, z2.b[9]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[12]
+; CHECK-NEXT: mov z3.b, z2.b[10]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[13]
+; CHECK-NEXT: mov z1.b, z2.b[11]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z3.b, w8
-; CHECK-NEXT: umov w8, v2.b[14]
+; CHECK-NEXT: mov z3.b, z2.b[12]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
-; CHECK-NEXT: umov w9, v2.b[15]
+; CHECK-NEXT: mov z1.b, z2.b[13]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT: mov z2.b, w8
+; CHECK-NEXT: mov z3.b, z2.b[14]
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT: mov z1.b, w9
+; CHECK-NEXT: mov z1.b, z2.b[15]
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z2.b
+; CHECK-NEXT: cmpeq p4.b, p1/z, z0.b, z3.b
; CHECK-NEXT: cmpeq p1.b, p1/z, z0.b, z1.b
; CHECK-NEXT: sel p2.b, p2, p2.b, p3.b
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
@@ -496,18 +465,15 @@ define <vscale x 4 x i1> @match_nxv4xi32_v4i32(<vscale x 4 x i32> %op1, <4 x i32
; CHECK-NEXT: str p4, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
-; CHECK-NEXT: mov w8, v1.s[1]
-; CHECK-NEXT: fmov w10, s1
-; CHECK-NEXT: mov w9, v1.s[2]
-; CHECK-NEXT: mov w11, v1.s[3]
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mov z2.s, z1.s[1]
+; CHECK-NEXT: mov z3.s, s1
; CHECK-NEXT: ptrue p1.s
-; CHECK-NEXT: mov z2.s, w10
-; CHECK-NEXT: mov z1.s, w8
-; CHECK-NEXT: mov z3.s, w9
-; CHECK-NEXT: cmpeq p3.s, p1/z, z0.s, z2.s
-; CHECK-NEXT: cmpeq p2.s, p1/z, z0.s, z1.s
-; CHECK-NEXT: mov z1.s, w11
-; CHECK-NEXT: cmpeq p4.s, p1/z, z0.s, z3.s
+; CHECK-NEXT: mov z4.s, z1.s[2]
+; CHECK-NEXT: mov z1.s, z1.s[3]
+; CHECK-NEXT: cmpeq p2.s, p1/z, z0.s, z2.s
+; CHECK-NEXT: cmpeq p3.s, p1/z, z0.s, z3.s
+; CHECK-NEXT: cmpeq p4.s, p1/z, z0.s, z4.s
; CHECK-NEXT: cmpeq p1.s, p1/z, z0.s, z1.s
; CHECK-NEXT: mov p2.b, p3/m, p3.b
; CHECK-NEXT: sel p2.b, p2, p2.b, p4.b
@@ -524,13 +490,12 @@ define <vscale x 4 x i1> @match_nxv4xi32_v4i32(<vscale x 4 x i32> %op1, <4 x i32
define <vscale x 2 x i1> @match_nxv2xi64_v2i64(<vscale x 2 x i64> %op1, <2 x i64> %op2, <vscale x 2 x i1> %mask) #0 {
; CHECK-LABEL: match_nxv2xi64_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: mov x8, v1.d[1]
-; CHECK-NEXT: fmov x9, d1
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mov z2.d, z1.d[1]
; CHECK-NEXT: ptrue p1.d
-; CHECK-NEXT: mov z2.d, x9
-; CHECK-NEXT: mov z1.d, x8
-; CHECK-NEXT: cmpeq p2.d, p1/z, z0.d, z1.d
-; CHECK-NEXT: cmpeq p1.d, p1/z, z0.d, z2.d
+; CHECK-NEXT: mov z1.d, d1
+; CHECK-NEXT: cmpeq p2.d, p1/z, z0.d, z2.d
+; CHECK-NEXT: cmpeq p1.d, p1/z, z0.d, z1.d
; CHECK-NEXT: sel p1.b, p1, p1.b, p2.b
; CHECK-NEXT: and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT: ret
>From e6668bc9b5432727c99e15ded2fa6178c9ca740d Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Wed, 13 Nov 2024 07:45:12 -0800
Subject: [PATCH 7/7] Replace call to MVT::getVectorVT with
getPackedSVEVectorVT
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b30cac3975ff4d..58b279aa5fac77 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5826,8 +5826,7 @@ SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
// If Op2 is not a full 128-bit vector, we always need to broadcast it.
unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
- MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
- /*IsScalable=*/true);
+ EVT Op2PromotedVT = getPackedSVEVectorVT(Op2IntVT);
Op2 = DAG.getBitcast(MVT::getVectorVT(Op2IntVT, 1), Op2);
Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT, Op2,
DAG.getConstant(0, dl, MVT::i64));
More information about the llvm-commits
mailing list