[llvm] [AArch64] Add @llvm.experimental.vector.match (PR #101974)

Ricardo Jesus via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 31 08:40:00 PDT 2024


https://github.com/rj-jesus updated https://github.com/llvm/llvm-project/pull/101974

>From e9bd6d43f01c815839b2cf82a4afddb92cb7b711 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Fri, 19 Jul 2024 16:10:51 +0100
Subject: [PATCH 1/4] [AArch64] Add @llvm.experimental.vector.match

This patch introduces an experimental intrinsic for matching the
elements of one vector against the elements of another.

For AArch64 targets that support SVE2, it lowers to a MATCH instruction
for supported fixed-length and scalable vector types. Otherwise, the
intrinsic is lowered generically in SelectionDAGBuilder.
---
 llvm/docs/LangRef.rst                         |  39 +++
 .../llvm/Analysis/TargetTransformInfo.h       |  10 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   4 +
 llvm/include/llvm/IR/Intrinsics.td            |   8 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 +
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  36 +++
 llvm/lib/IR/Verifier.cpp                      |  21 ++
 .../Target/AArch64/AArch64ISelLowering.cpp    |  53 ++++
 .../AArch64/AArch64TargetTransformInfo.cpp    |  24 ++
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +
 .../AArch64/intrinsic-vector-match-sve2.ll    | 253 ++++++++++++++++++
 11 files changed, 455 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 566d0d4e4e81a3..aedb101c4af68c 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20043,6 +20043,45 @@ are undefined.
     }
 
 
+'``llvm.experimental.vector.match.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. Support for specific vector types is target
+dependent.
+
+::
+
+    declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<m> x <ty>> %op2, <<n> x i1> %mask)
+    declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <<m> x <ty>> %op2, <vscale x <n> x i1> %mask)
+
+Overview:
+"""""""""
+
+Find active elements of the first argument matching any elements of the second.
+
+Arguments:
+""""""""""
+
+The first argument is the search vector, the second argument the vector of
+elements we are searching for (i.e. for which we consider a match successful),
+and the third argument is a mask that controls which elements of the first
+argument are active.
+
+Semantics:
+""""""""""
+
+The '``llvm.experimental.vector.match``' intrinsic compares each active element
+in the first argument against the elements of the second argument, placing
+``1`` in the corresponding element of the output vector if any comparison is
+successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
+in the output.
+
+The second argument needs to be a fixed-length vector with the same element
+type as the first argument.
+
 Matrix Intrinsics
 -----------------
 
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..f47874bf0407d5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1771,6 +1771,11 @@ class TargetTransformInfo {
   /// This should also apply to lowering for vector funnel shifts (rotates).
   bool isVectorShiftByScalarCheap(Type *Ty) const;
 
+  /// \returns True if the target has hardware support for vector match
+  /// operations between vectors of type `VT` and search vectors of `SearchSize`
+  /// elements, and false otherwise.
+  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
+
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
@@ -2221,6 +2226,7 @@ class TargetTransformInfo::Concept {
                              SmallVectorImpl<Use *> &OpsToSink) const = 0;
 
   virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
+  virtual bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3014,6 +3020,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isVectorShiftByScalarCheap(Ty);
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const override {
+    return Impl.hasVectorMatch(VT, SearchSize);
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..886acb5120330f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -995,6 +995,10 @@ class TargetTransformInfoImplBase {
 
   bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
+    return false;
+  }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 94e53f372127da..ccdaf79f63f1ff 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1918,6 +1918,14 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
                              [ IntrArgMemOnly ]>;
 
+// Experimental match
+def int_experimental_vector_match : DefaultAttrsIntrinsic<
+                             [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
+                             [ llvm_anyvector_ty,
+                               llvm_anyvector_ty,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],  // Mask
+                             [ IntrNoMem, IntrNoSync, IntrWillReturn ]>;
+
 // Operators
 let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
   // Integer arithmetic
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..ca7b258dc08d79 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1383,6 +1383,11 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
   return TTIImpl->isVectorShiftByScalarCheap(Ty);
 }
 
+bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
+                                         unsigned SearchSize) const {
+  return TTIImpl->hasVectorMatch(VT, SearchSize);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 8450553743074c..ca67f623f46258 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8156,6 +8156,42 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
              DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
     return;
   }
+  case Intrinsic::experimental_vector_match: {
+    SDValue Op1 = getValue(I.getOperand(0));
+    SDValue Op2 = getValue(I.getOperand(1));
+    SDValue Mask = getValue(I.getOperand(2));
+    EVT Op1VT = Op1.getValueType();
+    EVT Op2VT = Op2.getValueType();
+    EVT ResVT = Mask.getValueType();
+    unsigned SearchSize = Op2VT.getVectorNumElements();
+
+    LLVMContext &Ctx = *DAG.getContext();
+    const auto &TTI =
+        TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+
+    // If the target has native support for this vector match operation, lower
+    // the intrinsic directly; otherwise, lower it below.
+    if (TTI.hasVectorMatch(cast<VectorType>(Op1VT.getTypeForEVT(Ctx)),
+                           SearchSize)) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+
+    SDValue Ret = DAG.getNode(ISD::SPLAT_VECTOR, sdl, ResVT,
+                              DAG.getConstant(0, sdl, MVT::i1));
+
+    for (unsigned i = 0; i < SearchSize; ++i) {
+      SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
+                                    Op2VT.getVectorElementType(), Op2,
+                                    DAG.getVectorIdxConstant(i, sdl));
+      SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, sdl, Op1VT, Op2Elem);
+      SDValue Cmp = DAG.getSetCC(sdl, ResVT, Op1, Splat, ISD::SETEQ);
+      Ret = DAG.getNode(ISD::OR, sdl, ResVT, Ret, Cmp);
+    }
+
+    setValue(&I, DAG.getNode(ISD::AND, sdl, ResVT, Ret, Mask));
+    return;
+  }
   case Intrinsic::vector_reverse:
     visitVectorReverse(I);
     return;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index f34fe7594c8602..8a031956d70dd9 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6154,6 +6154,27 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           &Call);
     break;
   }
+  case Intrinsic::experimental_vector_match: {
+    Value *Op1 = Call.getArgOperand(0);
+    Value *Op2 = Call.getArgOperand(1);
+    Value *Mask = Call.getArgOperand(2);
+
+    VectorType *Op1Ty = dyn_cast<VectorType>(Op1->getType());
+    VectorType *Op2Ty = dyn_cast<VectorType>(Op2->getType());
+    VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());
+
+    Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
+    Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
+          &Call);
+    Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
+          "First two operands must have the same element type.", &Call);
+    Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
+          "First operand and mask must have the same number of elements.",
+          &Call);
+    Check(MaskTy->getElementType()->isIntegerTy(1),
+          "Mask must be a vector of i1's.", &Call);
+    break;
+  }
   case Intrinsic::vector_insert: {
     Value *Vec = Call.getArgOperand(0);
     Value *SubVec = Call.getArgOperand(1);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4aa123b42d1966..979a0cd904f4ad 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6364,6 +6364,58 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
         DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
+  case Intrinsic::experimental_vector_match: {
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+    auto Op1 = Op.getOperand(1);
+    auto Op2 = Op.getOperand(2);
+    auto Mask = Op.getOperand(3);
+
+    EVT Op1VT = Op1.getValueType();
+    EVT Op2VT = Op2.getValueType();
+    EVT ResVT = Op.getValueType();
+
+    assert((Op1VT.getVectorElementType() == MVT::i8 ||
+            Op1VT.getVectorElementType() == MVT::i16) &&
+           "Expected 8-bit or 16-bit characters.");
+    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+    assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
+           "Operand type mismatch.");
+    assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
+           "Invalid operands.");
+
+    // Wrap the search vector in a scalable vector.
+    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
+    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+
+    // If the result is scalable, we need to broadbast the search vector across
+    // the SVE register and then carry out the MATCH.
+    if (ResVT.isScalableVector()) {
+      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                        DAG.getTargetConstant(0, dl, MVT::i64));
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
+                         Op2);
+    }
+
+    // If the result is fixed, we can still use MATCH but we need to wrap the
+    // first operand and the mask in scalable vectors before doing so.
+    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
+
+    // Wrap the operands.
+    Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+    Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
+    Mask = convertFixedMaskToScalableVector(Mask, DAG);
+
+    // Carry out the match.
+    SDValue Match =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+
+    // Extract and return the result.
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
+                       DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+                       DAG.getVectorIdxConstant(0, dl));
+  }
   }
 }
 
@@ -27046,6 +27098,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::experimental_vector_match:
     case Intrinsic::get_active_lane_mask: {
       if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
         return;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ff3c69f7e10c66..9d33a368a9e86d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4072,6 +4072,30 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
   }
 }
 
+bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
+  // Check that (i) the target has SVE2 and SVE is available, (ii) `VT' is a
+  // legal type for MATCH, and (iii) the search vector can be broadcast
+  // efficently to a legal type.
+  //
+  // Currently, we require the length of the search vector to match the minimum
+  // number of elements of `VT'. In practice this means we only support the
+  // cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
+  // first element of the tuples corresponds to the type of the first argument
+  // and the second the length of the search vector.
+  //
+  // In the future we can support more cases. For example, (nxv16i8, 4) could
+  // be efficiently supported by using a DUP.S to broadcast the search
+  // elements, and more exotic cases like (nxv16i8, 5) could be supported by a
+  // sequence of SEL(DUP).
+  if (ST->hasSVE2() && ST->isSVEAvailable() &&
+      VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
+      (VT->getElementCount().getKnownMinValue() == 8 ||
+       VT->getElementCount().getKnownMinValue() == 16) &&
+      VT->getElementCount().getKnownMinValue() == SearchSize)
+    return true;
+  return false;
+}
+
 InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                        FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1d09d67f6ec9e3..580bf5e79c3da1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -392,6 +392,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return ST->hasSVE();
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
+
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                              std::optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
new file mode 100644
index 00000000000000..d84a54f327a9bc
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -0,0 +1,253 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64 < %s -o - | FileCheck %s
+
+define <vscale x 16 x i1> @match_nxv16i8_v1i8(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v1i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    umov w8, v1.b[0]
+; CHECK-NEXT:    mov z1.b, w8
+; CHECK-NEXT:    cmpeq p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <1 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v2i8(<vscale x 16 x i8> %op1, <2 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v2i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    mov w8, v1.s[1]
+; CHECK-NEXT:    fmov w9, s1
+; CHECK-NEXT:    ptrue p1.b
+; CHECK-NEXT:    mov z2.b, w9
+; CHECK-NEXT:    mov z1.b, w8
+; CHECK-NEXT:    cmpeq p2.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    cmpeq p1.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    sel p1.b, p1, p1.b, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <2 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v4i8(<vscale x 16 x i8> %op1, <4 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v4i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    umov w8, v1.h[1]
+; CHECK-NEXT:    umov w9, v1.h[0]
+; CHECK-NEXT:    umov w10, v1.h[2]
+; CHECK-NEXT:    ptrue p1.b
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    mov z3.b, w9
+; CHECK-NEXT:    umov w8, v1.h[3]
+; CHECK-NEXT:    mov z1.b, w10
+; CHECK-NEXT:    cmpeq p2.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    cmpeq p1.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    mov p2.b, p3/m, p3.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    mov p1.b, p2/m, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <4 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v8i8(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v8i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    umov w8, v1.b[1]
+; CHECK-NEXT:    umov w9, v1.b[0]
+; CHECK-NEXT:    umov w10, v1.b[2]
+; CHECK-NEXT:    ptrue p1.b
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    mov z3.b, w9
+; CHECK-NEXT:    umov w8, v1.b[3]
+; CHECK-NEXT:    mov z4.b, w10
+; CHECK-NEXT:    umov w9, v1.b[4]
+; CHECK-NEXT:    umov w10, v1.b[7]
+; CHECK-NEXT:    cmpeq p2.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    umov w8, v1.b[5]
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z3.b, w9
+; CHECK-NEXT:    umov w9, v1.b[6]
+; CHECK-NEXT:    mov p2.b, p3/m, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    mov z1.b, w8
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z2.b, w9
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w10
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    cmpeq p1.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    mov p1.b, p2/m, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 16 x i1> @match_nxv16i8_v16i8(<vscale x 16 x i8> %op1, <16 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mov z1.q, q1
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <16 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v1i8(<16 x i8> %op1, <1 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v1i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    dup v1.16b, v1.b[0]
+; CHECK-NEXT:    cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <1 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v2i8(<16 x i8> %op1, <2 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v2i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    dup v3.16b, v1.b[4]
+; CHECK-NEXT:    dup v1.16b, v1.b[0]
+; CHECK-NEXT:    cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT:    cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v0.16b, v3.16b
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <2 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v4i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    dup v3.16b, v1.b[2]
+; CHECK-NEXT:    dup v4.16b, v1.b[0]
+; CHECK-NEXT:    dup v5.16b, v1.b[4]
+; CHECK-NEXT:    dup v1.16b, v1.b[6]
+; CHECK-NEXT:    cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT:    cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT:    cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT:    cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v1.16b, v4.16b, v3.16b
+; CHECK-NEXT:    orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v8i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    dup v3.16b, v1.b[1]
+; CHECK-NEXT:    dup v4.16b, v1.b[0]
+; CHECK-NEXT:    dup v5.16b, v1.b[2]
+; CHECK-NEXT:    dup v6.16b, v1.b[3]
+; CHECK-NEXT:    dup v7.16b, v1.b[4]
+; CHECK-NEXT:    dup v16.16b, v1.b[5]
+; CHECK-NEXT:    dup v17.16b, v1.b[6]
+; CHECK-NEXT:    dup v1.16b, v1.b[7]
+; CHECK-NEXT:    cmeq v3.16b, v0.16b, v3.16b
+; CHECK-NEXT:    cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT:    cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT:    cmeq v6.16b, v0.16b, v6.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v7.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT:    orr v3.16b, v4.16b, v3.16b
+; CHECK-NEXT:    orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT:    orr v5.16b, v7.16b, v16.16b
+; CHECK-NEXT:    cmeq v6.16b, v0.16b, v17.16b
+; CHECK-NEXT:    cmeq v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v3.16b, v3.16b, v4.16b
+; CHECK-NEXT:    orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT:    orr v3.16b, v3.16b, v4.16b
+; CHECK-NEXT:    orr v0.16b, v3.16b, v0.16b
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+define <vscale x 8 x i1> @match_nxv8i16_v8i16(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv8i16_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mov z1.q, q1
+; CHECK-NEXT:    match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 8 x i1> @llvm.experimental.vector.match(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask)
+  ret <vscale x 8 x i1> %r
+}
+
+define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    cmpne p0.h, p0/z, z2.h, #0
+; CHECK-NEXT:    match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    ret
+  %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask)
+  ret <8 x i1> %r
+}
+
+attributes #0 = { "target-features"="+sve2" }

>From 3f9398dd3603962f280e8f0fa760079e446ec1db Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Wed, 16 Oct 2024 10:11:00 -0700
Subject: [PATCH 2/4] Add support to lower partial search vectors

Also address other review comments.
---
 llvm/lib/IR/Verifier.cpp                      |   2 +
 .../Target/AArch64/AArch64ISelLowering.cpp    |  88 +++-
 .../AArch64/AArch64TargetTransformInfo.cpp    |  15 +-
 .../AArch64/intrinsic-vector-match-sve2.ll    | 419 +++++++++++++++---
 4 files changed, 424 insertions(+), 100 deletions(-)

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8a031956d70dd9..09ec2ff4628d36 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6173,6 +6173,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           &Call);
     Check(MaskTy->getElementType()->isIntegerTy(1),
           "Mask must be a vector of i1's.", &Call);
+    Check(Call.getType() == MaskTy, "Return type must match the mask type.",
+          &Call);
     break;
   }
   case Intrinsic::vector_insert: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 979a0cd904f4ad..d9a21d7f60df2c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6379,42 +6379,86 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     assert((Op1VT.getVectorElementType() == MVT::i8 ||
             Op1VT.getVectorElementType() == MVT::i16) &&
            "Expected 8-bit or 16-bit characters.");
-    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
     assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
            "Operand type mismatch.");
-    assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
-           "Invalid operands.");
-
-    // Wrap the search vector in a scalable vector.
-    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
-    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
-
-    // If the result is scalable, we need to broadbast the search vector across
-    // the SVE register and then carry out the MATCH.
-    if (ResVT.isScalableVector()) {
-      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
-                        DAG.getTargetConstant(0, dl, MVT::i64));
+    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+
+    // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
+    // In the future we could support other types (e.g. v8i8).
+    assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
+           "Unsupported first operand type.");
+
+    // Scalable vector type used to wrap operands.
+    // A single container is enough for both operands because ultimately the
+    // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+    EVT OpContainerVT = Op1VT.isScalableVector()
+                            ? Op1VT
+                            : getContainerForFixedLengthVector(DAG, Op1VT);
+
+    // Wrap Op2 in a scalable register, and splat it if necessary.
+    if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
+      // If Op1 and Op2 have the same number of elements we can trivially
+      // wrapping Op2 in an SVE register.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      // If the result is scalable, we need to broadcast Op2 to a full SVE
+      // register.
+      if (ResVT.isScalableVector())
+        Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                          DAG.getTargetConstant(0, dl, MVT::i64));
+    } else {
+      // If Op1 and Op2 have different number of elements, we need to broadcast
+      // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
+      // similarly to the above, but unfortunately it seems we are missing some
+      // patterns for this. So, in alternative, we splat Op2 through a splat of
+      // a scalable vector extract. This idiom, though a bit more verbose, is
+      // supported and get us the MOV instruction we want.
+
+      // Some types we need. We'll use an integer type with `Op2BitWidth' bits
+      // to wrap Op2 and simulate the DUPLANE.
+      unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+      MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+      MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
+      EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
+      // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
+      // SVE register before doing the extract and splat.
+      // It is unlikely we'll be widening from types other than v8i8 or v4i16,
+      // so in practice this loop will run for a single iteration.
+      while (Op2VT.getFixedSizeInBits() != 128) {
+        Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
+        Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
+                          DAG.getUNDEF(Op2.getValueType()));
+      }
+      // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+                        DAG.getBitcast(Op2ScalableVT, Op2),
+                        DAG.getConstant(0, dl, MVT::i64));
+      Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+      Op2 = DAG.getBitcast(OpContainerVT, Op2);
+    }
+
+    // If the result is scalable, we just need to carry out the MATCH.
+    if (ResVT.isScalableVector())
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
                          Op2);
-    }
 
     // If the result is fixed, we can still use MATCH but we need to wrap the
     // first operand and the mask in scalable vectors before doing so.
-    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
 
     // Wrap the operands.
     Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
     Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
     Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
-    // Carry out the match.
-    SDValue Match =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+    // Carry out the match and extract it.
+    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
+                                Mask.getValueType(), ID, Mask, Op1, Op2);
+    Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
+                        DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+                        DAG.getVectorIdxConstant(0, dl));
 
-    // Extract and return the result.
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
-                       DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
-                       DAG.getVectorIdxConstant(0, dl));
+    // Truncate and return the result.
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
   }
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 9d33a368a9e86d..4a68ecce3b654c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4077,21 +4077,14 @@ bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
   // legal type for MATCH, and (iii) the search vector can be broadcast
   // efficently to a legal type.
   //
-  // Currently, we require the length of the search vector to match the minimum
-  // number of elements of `VT'. In practice this means we only support the
-  // cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
-  // first element of the tuples corresponds to the type of the first argument
-  // and the second the length of the search vector.
-  //
-  // In the future we can support more cases. For example, (nxv16i8, 4) could
-  // be efficiently supported by using a DUP.S to broadcast the search
-  // elements, and more exotic cases like (nxv16i8, 5) could be supported by a
-  // sequence of SEL(DUP).
+  // Currently, we require the search vector to be 64-bit or 128-bit. In the
+  // future we can support more cases.
   if (ST->hasSVE2() && ST->isSVEAvailable() &&
       VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
       (VT->getElementCount().getKnownMinValue() == 8 ||
        VT->getElementCount().getKnownMinValue() == 16) &&
-      VT->getElementCount().getKnownMinValue() == SearchSize)
+      (VT->getElementCount().getKnownMinValue() == SearchSize ||
+       VT->getElementCount().getKnownMinValue() / 2 == SearchSize))
     return true;
   return false;
 }
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index d84a54f327a9bc..e991bbc80962b5 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -68,48 +68,9 @@ define <vscale x 16 x i1> @match_nxv16i8_v4i8(<vscale x 16 x i8> %op1, <4 x i8>
 define <vscale x 16 x i1> @match_nxv16i8_v8i8(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
 ; CHECK-LABEL: match_nxv16i8_v8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-NEXT:    .cfi_offset w29, -16
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    umov w8, v1.b[1]
-; CHECK-NEXT:    umov w9, v1.b[0]
-; CHECK-NEXT:    umov w10, v1.b[2]
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    mov z2.b, w8
-; CHECK-NEXT:    mov z3.b, w9
-; CHECK-NEXT:    umov w8, v1.b[3]
-; CHECK-NEXT:    mov z4.b, w10
-; CHECK-NEXT:    umov w9, v1.b[4]
-; CHECK-NEXT:    umov w10, v1.b[7]
-; CHECK-NEXT:    cmpeq p2.b, p1/z, z0.b, z2.b
-; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
-; CHECK-NEXT:    mov z2.b, w8
-; CHECK-NEXT:    umov w8, v1.b[5]
-; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
-; CHECK-NEXT:    mov z3.b, w9
-; CHECK-NEXT:    umov w9, v1.b[6]
-; CHECK-NEXT:    mov p2.b, p3/m, p3.b
-; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z2.b
-; CHECK-NEXT:    mov z1.b, w8
-; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
-; CHECK-NEXT:    mov z2.b, w9
-; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
-; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
-; CHECK-NEXT:    mov z1.b, w10
-; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z2.b
-; CHECK-NEXT:    cmpeq p1.b, p1/z, z0.b, z1.b
-; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
-; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
-; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
-; CHECK-NEXT:    mov p1.b, p2/m, p2.b
-; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mov z1.d, d1
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
 ; CHECK-NEXT:    ret
   %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <8 x i8> %op2, <vscale x 16 x i1> %mask)
   ret <vscale x 16 x i1> %r
@@ -177,31 +138,15 @@ define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mas
 define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
 ; CHECK-LABEL: match_v16i8_v8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    dup v3.16b, v1.b[1]
-; CHECK-NEXT:    dup v4.16b, v1.b[0]
-; CHECK-NEXT:    dup v5.16b, v1.b[2]
-; CHECK-NEXT:    dup v6.16b, v1.b[3]
-; CHECK-NEXT:    dup v7.16b, v1.b[4]
-; CHECK-NEXT:    dup v16.16b, v1.b[5]
-; CHECK-NEXT:    dup v17.16b, v1.b[6]
-; CHECK-NEXT:    dup v1.16b, v1.b[7]
-; CHECK-NEXT:    cmeq v3.16b, v0.16b, v3.16b
-; CHECK-NEXT:    cmeq v4.16b, v0.16b, v4.16b
-; CHECK-NEXT:    cmeq v5.16b, v0.16b, v5.16b
-; CHECK-NEXT:    cmeq v6.16b, v0.16b, v6.16b
-; CHECK-NEXT:    cmeq v7.16b, v0.16b, v7.16b
-; CHECK-NEXT:    cmeq v16.16b, v0.16b, v16.16b
-; CHECK-NEXT:    orr v3.16b, v4.16b, v3.16b
-; CHECK-NEXT:    orr v4.16b, v5.16b, v6.16b
-; CHECK-NEXT:    orr v5.16b, v7.16b, v16.16b
-; CHECK-NEXT:    cmeq v6.16b, v0.16b, v17.16b
-; CHECK-NEXT:    cmeq v0.16b, v0.16b, v1.16b
-; CHECK-NEXT:    orr v3.16b, v3.16b, v4.16b
-; CHECK-NEXT:    orr v4.16b, v5.16b, v6.16b
-; CHECK-NEXT:    orr v3.16b, v3.16b, v4.16b
-; CHECK-NEXT:    orr v0.16b, v3.16b, v0.16b
-; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    mov z1.d, d1
+; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask)
   ret <16 x i1> %r
@@ -250,4 +195,344 @@ define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0
   ret <8 x i1> %r
 }
 
+; Cases where op2 has more elements than op1.
+
+define <vscale x 16 x i1> @match_nxv16i8_v32i8(<vscale x 16 x i8> %op1, <32 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8_v32i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    umov w8, v1.b[1]
+; CHECK-NEXT:    umov w9, v1.b[0]
+; CHECK-NEXT:    umov w10, v1.b[2]
+; CHECK-NEXT:    ptrue p1.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w8, v1.b[3]
+; CHECK-NEXT:    mov z5.b, w10
+; CHECK-NEXT:    umov w9, v1.b[4]
+; CHECK-NEXT:    umov w10, v1.b[15]
+; CHECK-NEXT:    cmpeq p2.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v1.b[5]
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z5.b
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w9, v1.b[6]
+; CHECK-NEXT:    mov p2.b, p3/m, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v1.b[7]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w9, v1.b[8]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v1.b[9]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w9, v1.b[10]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v1.b[11]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w9, v1.b[12]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v1.b[13]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z4.b, w9
+; CHECK-NEXT:    umov w9, v1.b[14]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z1.b, w8
+; CHECK-NEXT:    umov w8, v2.b[0]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z4.b
+; CHECK-NEXT:    mov z3.b, w9
+; CHECK-NEXT:    umov w9, v2.b[1]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w10
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[2]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[3]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[4]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[5]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[6]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[7]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[8]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[9]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[10]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[11]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[12]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[13]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z3.b, w8
+; CHECK-NEXT:    umov w8, v2.b[14]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    umov w9, v2.b[15]
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z3.b
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    cmpeq p3.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    mov z1.b, w9
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    cmpeq p4.b, p1/z, z0.b, z2.b
+; CHECK-NEXT:    cmpeq p1.b, p1/z, z0.b, z1.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p3.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    mov p1.b, p2/m, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <32 x i8> %op2, <vscale x 16 x i1> %mask)
+  ret <vscale x 16 x i1> %r
+}
+
+define <16 x i1> @match_v16i8_v32i8(<16 x i8> %op1, <32 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8_v32i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v4.16b, v1.b[1]
+; CHECK-NEXT:    dup v5.16b, v1.b[0]
+; CHECK-NEXT:    dup v6.16b, v1.b[2]
+; CHECK-NEXT:    dup v7.16b, v1.b[3]
+; CHECK-NEXT:    dup v16.16b, v1.b[4]
+; CHECK-NEXT:    dup v17.16b, v1.b[5]
+; CHECK-NEXT:    dup v18.16b, v1.b[6]
+; CHECK-NEXT:    dup v19.16b, v1.b[7]
+; CHECK-NEXT:    dup v20.16b, v1.b[8]
+; CHECK-NEXT:    cmeq v4.16b, v0.16b, v4.16b
+; CHECK-NEXT:    cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT:    cmeq v6.16b, v0.16b, v6.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v7.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT:    cmeq v17.16b, v0.16b, v17.16b
+; CHECK-NEXT:    dup v21.16b, v2.b[7]
+; CHECK-NEXT:    dup v22.16b, v1.b[10]
+; CHECK-NEXT:    orr v4.16b, v5.16b, v4.16b
+; CHECK-NEXT:    orr v5.16b, v6.16b, v7.16b
+; CHECK-NEXT:    orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT:    cmeq v17.16b, v0.16b, v20.16b
+; CHECK-NEXT:    dup v18.16b, v1.b[9]
+; CHECK-NEXT:    dup v19.16b, v1.b[11]
+; CHECK-NEXT:    dup v20.16b, v1.b[12]
+; CHECK-NEXT:    cmeq v22.16b, v0.16b, v22.16b
+; CHECK-NEXT:    orr v4.16b, v4.16b, v5.16b
+; CHECK-NEXT:    orr v5.16b, v6.16b, v7.16b
+; CHECK-NEXT:    orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT:    dup v18.16b, v1.b[13]
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT:    cmeq v17.16b, v0.16b, v20.16b
+; CHECK-NEXT:    dup v19.16b, v2.b[0]
+; CHECK-NEXT:    dup v20.16b, v2.b[1]
+; CHECK-NEXT:    orr v4.16b, v4.16b, v5.16b
+; CHECK-NEXT:    dup v5.16b, v2.b[6]
+; CHECK-NEXT:    orr v6.16b, v6.16b, v7.16b
+; CHECK-NEXT:    orr v7.16b, v16.16b, v17.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v18.16b
+; CHECK-NEXT:    cmeq v17.16b, v0.16b, v19.16b
+; CHECK-NEXT:    cmeq v18.16b, v0.16b, v20.16b
+; CHECK-NEXT:    dup v19.16b, v2.b[2]
+; CHECK-NEXT:    cmeq v5.16b, v0.16b, v5.16b
+; CHECK-NEXT:    cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT:    dup v21.16b, v2.b[8]
+; CHECK-NEXT:    orr v6.16b, v6.16b, v22.16b
+; CHECK-NEXT:    orr v7.16b, v7.16b, v16.16b
+; CHECK-NEXT:    dup v16.16b, v1.b[14]
+; CHECK-NEXT:    dup v1.16b, v1.b[15]
+; CHECK-NEXT:    orr v17.16b, v17.16b, v18.16b
+; CHECK-NEXT:    cmeq v18.16b, v0.16b, v19.16b
+; CHECK-NEXT:    dup v19.16b, v2.b[3]
+; CHECK-NEXT:    orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT:    cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT:    dup v21.16b, v2.b[9]
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v16.16b
+; CHECK-NEXT:    cmeq v1.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v4.16b, v4.16b, v6.16b
+; CHECK-NEXT:    orr v17.16b, v17.16b, v18.16b
+; CHECK-NEXT:    cmeq v18.16b, v0.16b, v19.16b
+; CHECK-NEXT:    dup v19.16b, v2.b[4]
+; CHECK-NEXT:    orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT:    cmeq v20.16b, v0.16b, v21.16b
+; CHECK-NEXT:    dup v21.16b, v2.b[10]
+; CHECK-NEXT:    orr v7.16b, v7.16b, v16.16b
+; CHECK-NEXT:    orr v16.16b, v17.16b, v18.16b
+; CHECK-NEXT:    cmeq v17.16b, v0.16b, v19.16b
+; CHECK-NEXT:    dup v18.16b, v2.b[5]
+; CHECK-NEXT:    orr v5.16b, v5.16b, v20.16b
+; CHECK-NEXT:    cmeq v19.16b, v0.16b, v21.16b
+; CHECK-NEXT:    dup v20.16b, v2.b[11]
+; CHECK-NEXT:    orr v1.16b, v7.16b, v1.16b
+; CHECK-NEXT:    orr v6.16b, v16.16b, v17.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT:    dup v17.16b, v2.b[12]
+; CHECK-NEXT:    orr v5.16b, v5.16b, v19.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v20.16b
+; CHECK-NEXT:    dup v18.16b, v2.b[13]
+; CHECK-NEXT:    dup v19.16b, v2.b[14]
+; CHECK-NEXT:    orr v1.16b, v4.16b, v1.16b
+; CHECK-NEXT:    dup v2.16b, v2.b[15]
+; CHECK-NEXT:    orr v4.16b, v6.16b, v7.16b
+; CHECK-NEXT:    cmeq v6.16b, v0.16b, v17.16b
+; CHECK-NEXT:    orr v5.16b, v5.16b, v16.16b
+; CHECK-NEXT:    cmeq v7.16b, v0.16b, v18.16b
+; CHECK-NEXT:    cmeq v16.16b, v0.16b, v19.16b
+; CHECK-NEXT:    cmeq v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    orr v1.16b, v1.16b, v4.16b
+; CHECK-NEXT:    orr v4.16b, v5.16b, v6.16b
+; CHECK-NEXT:    orr v5.16b, v7.16b, v16.16b
+; CHECK-NEXT:    orr v1.16b, v1.16b, v4.16b
+; CHECK-NEXT:    orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    and v0.16b, v0.16b, v3.16b
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <32 x i8> %op2, <16 x i1> %mask)
+  ret <16 x i1> %r
+}
+
+; Data types not supported by MATCH.
+; Note: The cases for SVE could be made tighter.
+
+define <vscale x 4 x i1> @match_nxv4xi32_v4i32(<vscale x 4 x i32> %op1, <4 x i32> %op2, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv4xi32_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    mov w8, v1.s[1]
+; CHECK-NEXT:    fmov w10, s1
+; CHECK-NEXT:    mov w9, v1.s[2]
+; CHECK-NEXT:    mov w11, v1.s[3]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mov z2.s, w10
+; CHECK-NEXT:    mov z1.s, w8
+; CHECK-NEXT:    mov z3.s, w9
+; CHECK-NEXT:    cmpeq p3.s, p1/z, z0.s, z2.s
+; CHECK-NEXT:    cmpeq p2.s, p1/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, w11
+; CHECK-NEXT:    cmpeq p4.s, p1/z, z0.s, z3.s
+; CHECK-NEXT:    cmpeq p1.s, p1/z, z0.s, z1.s
+; CHECK-NEXT:    mov p2.b, p3/m, p3.b
+; CHECK-NEXT:    sel p2.b, p2, p2.b, p4.b
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    mov p1.b, p2/m, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 4 x i1> @llvm.experimental.vector.match(<vscale x 4 x i32> %op1, <4 x i32> %op2, <vscale x 4 x i1> %mask)
+  ret <vscale x 4 x i1> %r
+}
+
+define <vscale x 2 x i1> @match_nxv2xi64_v2i64(<vscale x 2 x i64> %op1, <2 x i64> %op2, <vscale x 2 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv2xi64_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, v1.d[1]
+; CHECK-NEXT:    fmov x9, d1
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mov z2.d, x9
+; CHECK-NEXT:    mov z1.d, x8
+; CHECK-NEXT:    cmpeq p2.d, p1/z, z0.d, z1.d
+; CHECK-NEXT:    cmpeq p1.d, p1/z, z0.d, z2.d
+; CHECK-NEXT:    sel p1.b, p1, p1.b, p2.b
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 2 x i1> @llvm.experimental.vector.match(<vscale x 2 x i64> %op1, <2 x i64> %op2, <vscale x 2 x i1> %mask)
+  ret <vscale x 2 x i1> %r
+}
+
+define <4 x i1> @match_v4xi32_v4i32(<4 x i32> %op1, <4 x i32> %op2, <4 x i1> %mask) #0 {
+; CHECK-LABEL: match_v4xi32_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v3.4s, v1.s[1]
+; CHECK-NEXT:    dup v4.4s, v1.s[0]
+; CHECK-NEXT:    dup v5.4s, v1.s[2]
+; CHECK-NEXT:    dup v1.4s, v1.s[3]
+; CHECK-NEXT:    cmeq v3.4s, v0.4s, v3.4s
+; CHECK-NEXT:    cmeq v4.4s, v0.4s, v4.4s
+; CHECK-NEXT:    cmeq v5.4s, v0.4s, v5.4s
+; CHECK-NEXT:    cmeq v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    orr v1.16b, v4.16b, v3.16b
+; CHECK-NEXT:    orr v0.16b, v5.16b, v0.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    and v0.8b, v0.8b, v2.8b
+; CHECK-NEXT:    ret
+  %r = tail call <4 x i1> @llvm.experimental.vector.match(<4 x i32> %op1, <4 x i32> %op2, <4 x i1> %mask)
+  ret <4 x i1> %r
+}
+
+define <2 x i1> @match_v2xi64_v2i64(<2 x i64> %op1, <2 x i64> %op2, <2 x i1> %mask) #0 {
+; CHECK-LABEL: match_v2xi64_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v3.2d, v1.d[1]
+; CHECK-NEXT:    dup v1.2d, v1.d[0]
+; CHECK-NEXT:    cmeq v3.2d, v0.2d, v3.2d
+; CHECK-NEXT:    cmeq v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    orr v0.16b, v0.16b, v3.16b
+; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    and v0.8b, v0.8b, v2.8b
+; CHECK-NEXT:    ret
+  %r = tail call <2 x i1> @llvm.experimental.vector.match(<2 x i64> %op1, <2 x i64> %op2, <2 x i1> %mask)
+  ret <2 x i1> %r
+}
+
 attributes #0 = { "target-features"="+sve2" }

>From 56dedaff6107a8625e6ab0d50ae1d8c84b1cb350 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Thu, 24 Oct 2024 09:36:10 -0700
Subject: [PATCH 3/4] Address review comments

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 41 +++++++------------
 .../AArch64/AArch64TargetTransformInfo.cpp    | 20 ++++-----
 2 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d9a21d7f60df2c..c51c079b9b61a5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6397,8 +6397,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
 
     // Wrap Op2 in a scalable register, and splat it if necessary.
     if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
-      // If Op1 and Op2 have the same number of elements we can trivially
-      // wrapping Op2 in an SVE register.
+      // If Op1 and Op2 have the same number of elements we can trivially wrap
+      // Op2 in an SVE register.
       Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
       // If the result is scalable, we need to broadcast Op2 to a full SVE
       // register.
@@ -6408,32 +6408,21 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     } else {
       // If Op1 and Op2 have different number of elements, we need to broadcast
       // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
-      // similarly to the above, but unfortunately it seems we are missing some
+      // similarly to the above, but unfortunately we seem to be missing some
       // patterns for this. So, in alternative, we splat Op2 through a splat of
       // a scalable vector extract. This idiom, though a bit more verbose, is
       // supported and get us the MOV instruction we want.
-
-      // Some types we need. We'll use an integer type with `Op2BitWidth' bits
-      // to wrap Op2 and simulate the DUPLANE.
       unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
       MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
-      MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
-      EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
-      // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
-      // SVE register before doing the extract and splat.
-      // It is unlikely we'll be widening from types other than v8i8 or v4i16,
-      // so in practice this loop will run for a single iteration.
-      while (Op2VT.getFixedSizeInBits() != 128) {
-        Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
-        Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
-                          DAG.getUNDEF(Op2.getValueType()));
-      }
-      // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
-      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
+                                           /*IsScalable=*/true);
+      SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
+                                       DAG.getUNDEF(OpContainerVT), Op2,
+                                       DAG.getConstant(0, dl, MVT::i64));
       Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
-                        DAG.getBitcast(Op2ScalableVT, Op2),
+                        DAG.getBitcast(Op2PromotedVT, Op2Widened),
                         DAG.getConstant(0, dl, MVT::i64));
-      Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+      Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
       Op2 = DAG.getBitcast(OpContainerVT, Op2);
     }
 
@@ -6450,14 +6439,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
     Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
-    // Carry out the match and extract it.
+    // Carry out the match.
     SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
                                 Mask.getValueType(), ID, Mask, Op1, Op2);
-    Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
-                        DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
-                        DAG.getVectorIdxConstant(0, dl));
 
-    // Truncate and return the result.
+    // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
+    // (v16i8/v8i8).
+    Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
+    Match = convertFromScalableVector(DAG, Op1VT, Match);
     return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
   }
   }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 4a68ecce3b654c..c572337fbb79aa 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4073,18 +4073,18 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
 }
 
 bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
-  // Check that (i) the target has SVE2 and SVE is available, (ii) `VT' is a
-  // legal type for MATCH, and (iii) the search vector can be broadcast
-  // efficently to a legal type.
-  //
+  // Check that the target has SVE2 and SVE is available.
+  if (!ST->hasSVE2() || !ST->isSVEAvailable())
+    return false;
+
+  // Check that `VT' is a legal type for MATCH, and that the search vector can
+  // be broadcast efficently if necessary.
   // Currently, we require the search vector to be 64-bit or 128-bit. In the
-  // future we can support more cases.
-  if (ST->hasSVE2() && ST->isSVEAvailable() &&
+  // future we can generalise this to other lengths.
+  unsigned MinEC = VT->getElementCount().getKnownMinValue();
+  if ((MinEC == 8 || MinEC == 16) &&
       VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
-      (VT->getElementCount().getKnownMinValue() == 8 ||
-       VT->getElementCount().getKnownMinValue() == 16) &&
-      (VT->getElementCount().getKnownMinValue() == SearchSize ||
-       VT->getElementCount().getKnownMinValue() / 2 == SearchSize))
+      (MinEC == SearchSize || MinEC / 2 == SearchSize))
     return true;
   return false;
 }

>From d09a0cc07ed13b58d022e63fed2cf7c61da8a42e Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Thu, 31 Oct 2024 06:59:56 -0700
Subject: [PATCH 4/4] Move lowering to LowerVectorMatch and replace TTI hook
 with TLI

(And address other minor feedback.)
---
 llvm/docs/LangRef.rst                         |  14 +-
 .../llvm/Analysis/TargetTransformInfo.h       |  10 -
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   4 -
 llvm/include/llvm/CodeGen/TargetLowering.h    |   7 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 -
 .../SelectionDAG/SelectionDAGBuilder.cpp      |   7 +-
 llvm/lib/IR/Verifier.cpp                      |   2 +
 .../Target/AArch64/AArch64ISelLowering.cpp    | 175 +++++++++---------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   2 +
 .../AArch64/AArch64TargetTransformInfo.cpp    |  17 --
 .../AArch64/AArch64TargetTransformInfo.h      |   2 -
 .../AArch64/intrinsic-vector-match-sve2.ll    |  25 ++-
 12 files changed, 135 insertions(+), 135 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index aedb101c4af68c..53b4f746dae4f0 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20049,8 +20049,7 @@ are undefined.
 Syntax:
 """""""
 
-This is an overloaded intrinsic. Support for specific vector types is target
-dependent.
+This is an overloaded intrinsic.
 
 ::
 
@@ -20068,16 +20067,19 @@ Arguments:
 The first argument is the search vector, the second argument the vector of
 elements we are searching for (i.e. for which we consider a match successful),
 and the third argument is a mask that controls which elements of the first
-argument are active.
+argument are active. The first two arguments must be vectors of matching
+integer element types. The first and third arguments and the result type must
+have matching element counts (fixed or scalable). The second argument must be a
+fixed vector, but its length may be different from the remaining arguments.
 
 Semantics:
 """"""""""
 
 The '``llvm.experimental.vector.match``' intrinsic compares each active element
 in the first argument against the elements of the second argument, placing
-``1`` in the corresponding element of the output vector if any comparison is
-successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
-in the output.
+``1`` in the corresponding element of the output vector if any equality
+comparison is successful, and ``0`` otherwise. Inactive elements in the mask
+are set to ``0`` in the output.
 
 The second argument needs to be a fixed-length vector with the same element
 type as the first argument.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f47874bf0407d5..0459941fe05cdc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1771,11 +1771,6 @@ class TargetTransformInfo {
   /// This should also apply to lowering for vector funnel shifts (rotates).
   bool isVectorShiftByScalarCheap(Type *Ty) const;
 
-  /// \returns True if the target has hardware support for vector match
-  /// operations between vectors of type `VT` and search vectors of `SearchSize`
-  /// elements, and false otherwise.
-  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
-
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
@@ -2226,7 +2221,6 @@ class TargetTransformInfo::Concept {
                              SmallVectorImpl<Use *> &OpsToSink) const = 0;
 
   virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
-  virtual bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3020,10 +3014,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isVectorShiftByScalarCheap(Ty);
   }
 
-  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const override {
-    return Impl.hasVectorMatch(VT, SearchSize);
-  }
-
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 886acb5120330f..dbdfb4d8cdfa32 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -995,10 +995,6 @@ class TargetTransformInfoImplBase {
 
   bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
 
-  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
-    return false;
-  }
-
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 8e0cdc6f1a5e77..231cfd30b426c8 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -483,6 +483,13 @@ class TargetLoweringBase {
                                       bool ZeroIsPoison,
                                       const ConstantRange *VScaleRange) const;
 
+  /// Return true if the @llvm.experimental.vector.match intrinsic should be
+  /// expanded for vector type \p VT and search size \p SearchSize using
+  /// generic SelectionDAGBuilder code; false leaves lowering to the target.
+  virtual bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const {
+    return true;
+  }
+
   // Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
   // vecreduce(op(x, y)) for the reduction opcode RedOpc.
   virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index ca7b258dc08d79..a47462b61e03b2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1383,11 +1383,6 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
   return TTIImpl->isVectorShiftByScalarCheap(Ty);
 }
 
-bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
-                                         unsigned SearchSize) const {
-  return TTIImpl->hasVectorMatch(VT, SearchSize);
-}
-
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca67f623f46258..b0b2a3bdd11988 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8165,14 +8165,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     EVT ResVT = Mask.getValueType();
     unsigned SearchSize = Op2VT.getVectorNumElements();
 
-    LLVMContext &Ctx = *DAG.getContext();
-    const auto &TTI =
-        TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
-
     // If the target has native support for this vector match operation, lower
     // the intrinsic directly; otherwise, lower it below.
-    if (TTI.hasVectorMatch(cast<VectorType>(Op1VT.getTypeForEVT(Ctx)),
-                           SearchSize)) {
+    if (!TLI.shouldExpandVectorMatch(Op1VT, SearchSize)) {
       visitTargetIntrinsic(I, Intrinsic);
       return;
     }
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 09ec2ff4628d36..4f222269b5b878 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6166,6 +6166,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
     Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
           &Call);
+    Check(Op1Ty->getElementType()->isIntegerTy(),
+          "First operand must be a vector of integers.", &Call);
     Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
           "First two operands must have the same element type.", &Call);
     Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c51c079b9b61a5..d891b02da3d46c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2052,6 +2052,19 @@ bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
          VT != MVT::v4i1 && VT != MVT::v2i1;
 }
 
+bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
+                                                    unsigned SearchSize) const {
+  // MATCH requires SVE2 and is only available in non-streaming mode.
+  if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
+    return true;
+  // Use MATCH only for i8 data (8/16 needles) or i16 data (8 needles).
+  if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
+    return SearchSize != 8;
+  if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
+    return SearchSize != 8 && SearchSize != 16;
+  return true;
+}
+
 void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
 
@@ -5761,6 +5774,84 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
                       DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
 }
 
+static SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
+  SDLoc dl(Op);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+  auto Op1 = Op.getOperand(1);
+  auto Op2 = Op.getOperand(2);
+  auto Mask = Op.getOperand(3);
+
+  EVT Op1VT = Op1.getValueType();
+  EVT Op2VT = Op2.getValueType();
+  EVT ResVT = Op.getValueType();
+
+  assert((Op1VT.getVectorElementType() == MVT::i8 ||
+          Op1VT.getVectorElementType() == MVT::i16) &&
+         "Expected 8-bit or 16-bit characters.");
+
+  // Scalable vector type used to wrap operands.
+  // A single container is enough for both operands because ultimately the
+  // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+  EVT OpContainerVT = Op1VT.isScalableVector()
+                          ? Op1VT
+                          : getContainerForFixedLengthVector(DAG, Op1VT);
+
+  // Wrap Op2 in a scalable register, and splat it if necessary.
+  if (Op2VT.getFixedSizeInBits() == 128) {
+    // MATCH compares each element against every element of the 128-bit
+    // segment of Op2, so an Op2 filling a segment can be wrapped as is.
+    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+    // If the result is scalable, we need to broadcast Op2 to a full SVE
+    // register.
+    if (ResVT.isScalableVector())
+      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                        DAG.getTargetConstant(0, dl, MVT::i64));
+  } else {
+    // A narrower Op2 must be replicated across the segment; otherwise MATCH
+    // would also compare against the undefined lanes above it. Ideally we
+    // would use an AArch64ISD::DUPLANE* node for this, but some patterns are
+    // missing, so instead we bitcast Op2 to a single wide integer and splat
+    // that. This idiom, though a bit more verbose, is supported and gets us
+    // the MOV instruction we want.
+    unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+    MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+    MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
+                                         /*IsScalable=*/true);
+    SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
+                                     DAG.getUNDEF(OpContainerVT), Op2,
+                                     DAG.getConstant(0, dl, MVT::i64));
+    Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+                      DAG.getBitcast(Op2PromotedVT, Op2Widened),
+                      DAG.getConstant(0, dl, MVT::i64));
+    Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
+    Op2 = DAG.getBitcast(OpContainerVT, Op2);
+  }
+
+  // If the result is scalable, we just need to carry out the MATCH.
+  if (ResVT.isScalableVector())
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1, Op2);
+
+  // If the result is fixed, we can still use MATCH but we need to wrap the
+  // first operand and the mask in scalable vectors before doing so.
+
+  // Wrap the operands.
+  Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+  Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, Op1VT, Mask);
+  Mask = convertFixedMaskToScalableVector(Mask, DAG);
+
+  // Carry out the match.
+  SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Mask.getValueType(),
+                              ID, Mask, Op1, Op2);
+
+  // Promote the match result (nxv16i1/nxv8i1) to the container type, extract
+  // the fixed-length part and truncate it to ResVT (v16i1/v8i1).
+  Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
+  Match = convertFromScalableVector(DAG, Op1VT, Match);
+  return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -6365,89 +6456,7 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
   case Intrinsic::experimental_vector_match: {
-    SDValue ID =
-        DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
-
-    auto Op1 = Op.getOperand(1);
-    auto Op2 = Op.getOperand(2);
-    auto Mask = Op.getOperand(3);
-
-    EVT Op1VT = Op1.getValueType();
-    EVT Op2VT = Op2.getValueType();
-    EVT ResVT = Op.getValueType();
-
-    assert((Op1VT.getVectorElementType() == MVT::i8 ||
-            Op1VT.getVectorElementType() == MVT::i16) &&
-           "Expected 8-bit or 16-bit characters.");
-    assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
-           "Operand type mismatch.");
-    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
-
-    // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
-    // In the future we could support other types (e.g. v8i8).
-    assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
-           "Unsupported first operand type.");
-
-    // Scalable vector type used to wrap operands.
-    // A single container is enough for both operands because ultimately the
-    // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
-    EVT OpContainerVT = Op1VT.isScalableVector()
-                            ? Op1VT
-                            : getContainerForFixedLengthVector(DAG, Op1VT);
-
-    // Wrap Op2 in a scalable register, and splat it if necessary.
-    if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
-      // If Op1 and Op2 have the same number of elements we can trivially wrap
-      // Op2 in an SVE register.
-      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
-      // If the result is scalable, we need to broadcast Op2 to a full SVE
-      // register.
-      if (ResVT.isScalableVector())
-        Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
-                          DAG.getTargetConstant(0, dl, MVT::i64));
-    } else {
-      // If Op1 and Op2 have different number of elements, we need to broadcast
-      // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
-      // similarly to the above, but unfortunately we seem to be missing some
-      // patterns for this. So, in alternative, we splat Op2 through a splat of
-      // a scalable vector extract. This idiom, though a bit more verbose, is
-      // supported and get us the MOV instruction we want.
-      unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
-      MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
-      MVT Op2PromotedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth,
-                                           /*IsScalable=*/true);
-      SDValue Op2Widened = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpContainerVT,
-                                       DAG.getUNDEF(OpContainerVT), Op2,
-                                       DAG.getConstant(0, dl, MVT::i64));
-      Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
-                        DAG.getBitcast(Op2PromotedVT, Op2Widened),
-                        DAG.getConstant(0, dl, MVT::i64));
-      Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
-      Op2 = DAG.getBitcast(OpContainerVT, Op2);
-    }
-
-    // If the result is scalable, we just need to carry out the MATCH.
-    if (ResVT.isScalableVector())
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
-                         Op2);
-
-    // If the result is fixed, we can still use MATCH but we need to wrap the
-    // first operand and the mask in scalable vectors before doing so.
-
-    // Wrap the operands.
-    Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
-    Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
-    Mask = convertFixedMaskToScalableVector(Mask, DAG);
-
-    // Carry out the match.
-    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
-                                Mask.getValueType(), ID, Mask, Op1, Op2);
-
-    // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
-    // (v16i8/v8i8).
-    Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
-    Match = convertFromScalableVector(DAG, Op1VT, Match);
-    return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
+    return LowerVectorMatch(Op, DAG);
   }
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 160cd18ca53b32..d7eaf2e9b724f6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -993,6 +993,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandCttzElements(EVT VT) const override;
 
+  bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
+
   /// If a change in streaming mode is required on entry to/return from a
   /// function call it emits and returns the corresponding SMSTART or SMSTOP
   /// node. \p Condition should be one of the enum values from
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c572337fbb79aa..ff3c69f7e10c66 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4072,23 +4072,6 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
   }
 }
 
-bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
-  // Check that the target has SVE2 and SVE is available.
-  if (!ST->hasSVE2() || !ST->isSVEAvailable())
-    return false;
-
-  // Check that `VT' is a legal type for MATCH, and that the search vector can
-  // be broadcast efficently if necessary.
-  // Currently, we require the search vector to be 64-bit or 128-bit. In the
-  // future we can generalise this to other lengths.
-  unsigned MinEC = VT->getElementCount().getKnownMinValue();
-  if ((MinEC == 8 || MinEC == 16) &&
-      VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
-      (MinEC == SearchSize || MinEC / 2 == SearchSize))
-    return true;
-  return false;
-}
-
 InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                        FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 580bf5e79c3da1..1d09d67f6ec9e3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -392,8 +392,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return ST->hasSVE();
   }
 
-  bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
-
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                              std::optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
index e991bbc80962b5..1a004f4a574ea3 100644
--- a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -138,11 +138,12 @@ define <16 x i1> @match_v16i8_v4i8(<16 x i8> %op1, <4 x i8> %op2, <16 x i1> %mas
 define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mask) #0 {
 ; CHECK-LABEL: match_v16i8_v8i8:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v2.16b, v2.16b, #7
 ; CHECK-NEXT:    ptrue p0.b, vl16
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
 ; CHECK-NEXT:    mov z1.d, d1
+; CHECK-NEXT:    cmlt v2.16b, v2.16b, #0
 ; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
 ; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
@@ -155,10 +156,11 @@ define <16 x i1> @match_v16i8_v8i8(<16 x i8> %op1, <8 x i8> %op2, <16 x i1> %mas
 define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
 ; CHECK-LABEL: match_v16i8_v16i8:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v2.16b, v2.16b, #7
 ; CHECK-NEXT:    ptrue p0.b, vl16
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
 ; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    cmlt v2.16b, v2.16b, #0
 ; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
 ; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
@@ -168,6 +170,23 @@ define <16 x i1> @match_v16i8_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %m
   ret <16 x i1> %r
 }
 
+define <8 x i1> @match_v8i8_v8i8(<8 x i8> %op1, <8 x i8> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i8_v8i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v2.8b, v2.8b, #7
+; CHECK-NEXT:    ptrue p0.b, vl8
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    cmlt v2.8b, v2.8b, #0
+; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i8> %op1, <8 x i8> %op2, <8 x i1> %mask)
+  ret <8 x i1> %r
+}
+
 define <vscale x 8 x i1> @match_nxv8i16_v8i16(<vscale x 8 x i16> %op1, <8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
 ; CHECK-LABEL: match_nxv8i16_v8i16:
 ; CHECK:       // %bb.0:
@@ -186,6 +205,8 @@ define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0
 ; CHECK-NEXT:    ptrue p0.h, vl8
 ; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #15
+; CHECK-NEXT:    cmlt v2.8h, v2.8h, #0
 ; CHECK-NEXT:    cmpne p0.h, p0/z, z2.h, #0
 ; CHECK-NEXT:    match p0.h, p0/z, z0.h, z1.h
 ; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff



More information about the llvm-commits mailing list