[llvm] [AArch64] Add @llvm.experimental.vector.match (PR #101974)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 5 06:21:33 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-aarch64

Author: Ricardo Jesus (rj-jesus)

<details>
<summary>Changes</summary>

This patch introduces an experimental intrinsic for matching the elements of one vector against the elements of another.

For AArch64 targets that support SVE2, the intrinsic lowers to a MATCH instruction for supported fixed and scalable vector types. The lowering can be improved a bit for fixed-length vectors but this can be addressed later.

---
Full diff: https://github.com/llvm/llvm-project/pull/101974.diff


10 Files Affected:

- (modified) llvm/docs/LangRef.rst (+45) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+9) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+10) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+5) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+9) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+46) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+12) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2) 
- (added) llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll (+57) 


``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index b17e3c828ed3d..dd9851d1af078 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -19637,6 +19637,51 @@ are undefined.
     }
 
 
+'``llvm.experimental.vector.match.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. Support for specific vector types is target
+dependent.
+
+::
+
+    declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<n> x <ty>> %op2, <<n> x i1> %mask, i32 <segsize>)
+    declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <vscale x <n> x <ty>> %op2, <vscale x <n> x i1> %mask, i32 <segsize>)
+
+Overview:
+"""""""""
+
+Find elements of the first argument matching any elements of the second.
+
+Arguments:
+""""""""""
+
+The first argument is the search vector, the second argument is the vector of
+elements we are searching for (i.e. for which we consider a match successful),
+and the third argument is a mask that controls which elements of the first
+argument are active. The fourth argument is an immediate that sets the segment
+size for the search window.
+
+Semantics:
+""""""""""
+
+The '``llvm.experimental.vector.match``' intrinsic compares each element in the
+first argument against potentially several elements of the second, placing
+``1`` in the corresponding element of the output vector if any comparison is
+successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
+in the output. The segment size controls the number of elements of the second
+argument that are compared against.
+
+For example, for vectors with 16 elements, if ``segsize = 16`` then each
+element of the first argument is compared against all 16 elements of the second
+argument; but if ``segsize = 4``, then each of the first four elements of the
+first argument is compared against the first four elements of the second
+argument, each of the second four elements of the first argument is compared
+against the second four elements of the second argument, and so forth.
+
 Matrix Intrinsics
 -----------------
 
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 38e8b9da21397..786c13a177ccf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1746,6 +1746,10 @@ class TargetTransformInfo {
   bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                              Align Alignment) const;
 
+  /// \returns true if the target supports vector match operations for
+  /// the vector type `VT` using a segment size of `SegSize`.
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
+
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
@@ -2184,6 +2188,7 @@ class TargetTransformInfo::Concept {
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                                      Align Alignment) const = 0;
+  virtual bool hasVectorMatch(VectorType *VT, unsigned SegSize) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2952,6 +2957,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const override {
+    return Impl.hasVectorMatch(VT, SegSize);
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d208a710bb27f..36621861ab8c8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -958,6 +958,8 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const { return false; }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39f..f6d77aa596f60 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1892,6 +1892,16 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
                              [ IntrArgMemOnly ]>;
 
+// Experimental match
+def int_experimental_vector_match : DefaultAttrsIntrinsic<
+                             [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
+                             [ llvm_anyvector_ty,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,  // Mask
+                               llvm_i32_ty ],  // Segment size
+                             [ IntrNoMem, IntrNoSync, IntrWillReturn,
+                               ImmArg<ArgIndex<3>> ]>;
+
 // Operators
 let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
   // Integer arithmetic
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index dcde78925bfa9..d8314af0537fe 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1352,6 +1352,11 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
+                                         unsigned SegSize) const {
+  return TTIImpl->hasVectorMatch(VT, SegSize);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d617c7acd13c..9cb7d65975b9f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8096,6 +8096,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
              DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
     return;
   }
+  case Intrinsic::experimental_vector_match: {
+    auto *VT = dyn_cast<VectorType>(I.getOperand(0)->getType());
+    auto SegmentSize = cast<ConstantInt>(I.getOperand(3))->getLimitedValue();
+    const auto &TTI =
+        TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+    assert(VT && TTI.hasVectorMatch(VT, SegmentSize) && "Unsupported type!");
+    visitTargetIntrinsic(I, Intrinsic);
+    return;
+  }
   case Intrinsic::vector_reverse:
     visitVectorReverse(I);
     return;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7704321a0fc3a..050807142fc0a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6106,6 +6106,51 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
         DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
+  case Intrinsic::experimental_vector_match: {
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+    auto Op1 = Op.getOperand(1);
+    auto Op2 = Op.getOperand(2);
+    auto Mask = Op.getOperand(3);
+    auto SegmentSize =
+        cast<ConstantSDNode>(Op.getOperand(4))->getLimitedValue();
+
+    EVT VT = Op.getValueType();
+    auto MinNumElts = VT.getVectorMinNumElements();
+
+    assert(Op1.getValueType() == Op2.getValueType() && "Type mismatch.");
+    assert(Op1.getValueSizeInBits().getKnownMinValue() == 128 &&
+           "Custom lower only works on 128-bit segments.");
+    assert((Op1.getValueType().getVectorElementType() == MVT::i8  ||
+            Op1.getValueType().getVectorElementType() == MVT::i16) &&
+           "Custom lower only supports 8-bit or 16-bit characters.");
+    assert(SegmentSize == MinNumElts && "Custom lower needs segment size to "
+                                        "match minimum number of elements.");
+
+    if (VT.isScalableVector())
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Mask, Op1, Op2);
+
+    // We can use the SVE2 match instruction to lower this intrinsic by
+    // converting the operands to scalable vectors, doing a match, and then
+    // extracting a fixed-width subvector from the scalable vector.
+
+    EVT OpVT = Op1.getValueType();
+    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, OpVT);
+    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
+
+    auto ScalableOp1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+    auto ScalableOp2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+    auto ScalableMask = DAG.getNode(ISD::SIGN_EXTEND, dl, OpVT, Mask);
+    ScalableMask = convertFixedMaskToScalableVector(ScalableMask, DAG);
+
+    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID,
+                                ScalableMask, ScalableOp1, ScalableOp2);
+
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT,
+                       DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+                       DAG.getVectorIdxConstant(0, dl));
+  }
   }
 }
 
@@ -26544,6 +26589,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::experimental_vector_match:
     case Intrinsic::get_active_lane_mask: {
       if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
         return;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b8f19fa87e2ab..806dc856c5862 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3835,6 +3835,18 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
   }
 }
 
+bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SegSize) const {
+  // Check that the target has SVE2 (and SVE is available), that `VT` is a
+  // legal type for MATCH, and that the segment size is 128-bit.
+  if (ST->hasSVE2() && ST->isSVEAvailable() &&
+      VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
+      VT->getElementCount().getKnownMinValue() == SegSize &&
+      (VT->getElementCount().getKnownMinValue() ==  8 ||
+       VT->getElementCount().getKnownMinValue() == 16))
+    return true;
+  return false;
+}
+
 InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                        FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40b..6ad21a9e0a77a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -391,6 +391,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return ST->hasSVE();
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
+
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                              std::optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
diff --git a/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
new file mode 100644
index 0000000000000..0df92dfa80000
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll
@@ -0,0 +1,57 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64 < %s -o - | FileCheck %s
+
+define <vscale x 16 x i1> @match_nxv16i8(<vscale x 16 x i8> %op1, <vscale x 16 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <vscale x 16 x i8> %op2, <vscale x 16 x i1> %mask, i32 16)
+  ret <vscale x 16 x i1> %r
+}
+
+define <vscale x 8 x i1> @match_nxv8i16(<vscale x 8 x i16> %op1, <vscale x 8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: match_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %r = tail call <vscale x 8 x i1> @llvm.experimental.vector.match(<vscale x 8 x i16> %op1, <vscale x 8 x i16> %op2, <vscale x 8 x i1> %mask, i32 8)
+  ret <vscale x 8 x i1> %r
+}
+
+define <16 x i1> @match_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
+; CHECK-LABEL: match_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v2.16b, v2.16b, #7
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    cmlt v2.16b, v2.16b, #0
+; CHECK-NEXT:    cmpne p0.b, p0/z, z2.b, #0
+; CHECK-NEXT:    match p0.b, p0/z, z0.b, z1.b
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask, i32 16)
+  ret <16 x i1> %r
+}
+
+define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0 {
+; CHECK-LABEL: match_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #15
+; CHECK-NEXT:    cmlt v2.8h, v2.8h, #0
+; CHECK-NEXT:    cmpne p0.h, p0/z, z2.h, #0
+; CHECK-NEXT:    match p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    ret
+  %r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask, i32 8)
+  ret <8 x i1> %r
+}
+
+attributes #0 = { "target-features"="+sve2" }

``````````

</details>


https://github.com/llvm/llvm-project/pull/101974


More information about the llvm-commits mailing list