[llvm] [AArch64] Add @llvm.experimental.vector.match (PR #101974)

Ricardo Jesus via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 25 02:42:43 PDT 2024


================
@@ -6379,42 +6379,86 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     assert((Op1VT.getVectorElementType() == MVT::i8 ||
             Op1VT.getVectorElementType() == MVT::i16) &&
            "Expected 8-bit or 16-bit characters.");
-    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
     assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
            "Operand type mismatch.");
-    assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
-           "Invalid operands.");
-
-    // Wrap the search vector in a scalable vector.
-    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
-    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
-
-    // If the result is scalable, we need to broadbast the search vector across
-    // the SVE register and then carry out the MATCH.
-    if (ResVT.isScalableVector()) {
-      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
-                        DAG.getTargetConstant(0, dl, MVT::i64));
+    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+
+    // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
+    // In the future we could support other types (e.g. v8i8).
+    assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
+           "Unsupported first operand type.");
+
+    // Scalable vector type used to wrap operands.
+    // A single container is enough for both operands because ultimately the
+    // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+    EVT OpContainerVT = Op1VT.isScalableVector()
+                            ? Op1VT
+                            : getContainerForFixedLengthVector(DAG, Op1VT);
+
+    // Wrap Op2 in a scalable register, and splat it if necessary.
+    if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
+      // If Op1 and Op2 have the same number of elements we can trivially
+      // wrap Op2 in an SVE register.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      // If the result is scalable, we need to broadcast Op2 to a full SVE
+      // register.
+      if (ResVT.isScalableVector())
+        Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                          DAG.getTargetConstant(0, dl, MVT::i64));
+    } else {
+      // If Op1 and Op2 have a different number of elements, we need to
+      // broadcast Op2. Ideally we would use a AArch64ISD::DUPLANE* node for
+      // this similarly to the above, but unfortunately it seems we are
+      // missing some patterns for this. So, as an alternative, we splat Op2
+      // through a splat of a scalable vector extract. This idiom, though a
+      // bit more verbose, is supported and gets us the MOV instruction we
+      // want.
+
+      // Some types we need. We'll use an integer type with `Op2BitWidth' bits
+      // to wrap Op2 and simulate the DUPLANE.
+      unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+      MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+      MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
+      EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
+      // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
+      // SVE register before doing the extract and splat.
+      // It is unlikely we'll be widening from types other than v8i8 or v4i16,
+      // so in practice this loop will run for a single iteration.
+      while (Op2VT.getFixedSizeInBits() != 128) {
+        Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
+        Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
+                          DAG.getUNDEF(Op2.getValueType()));
+      }
+      // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+                        DAG.getBitcast(Op2ScalableVT, Op2),
+                        DAG.getConstant(0, dl, MVT::i64));
+      Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+      Op2 = DAG.getBitcast(OpContainerVT, Op2);
+    }
+
+    // If the result is scalable, we just need to carry out the MATCH.
+    if (ResVT.isScalableVector())
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
                          Op2);
-    }
 
     // If the result is fixed, we can still use MATCH but we need to wrap the
     // first operand and the mask in scalable vectors before doing so.
-    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
 
     // Wrap the operands.
     Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
     Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
     Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
-    // Carry out the match.
-    SDValue Match =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+    // Carry out the match and extract it.
+    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
+                                Mask.getValueType(), ID, Mask, Op1, Op2);
+    Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
----------------
rj-jesus wrote:

I think the problem here is that you can't extract from `Match` directly because, at this point, `ResVT` has been legalised to something like v16i8 or v8i8. I've tried to tidy this up a bit with `convertFromScalableVector`—please let me know if you think that's better or if you have any better suggestions!

https://github.com/llvm/llvm-project/pull/101974


More information about the llvm-commits mailing list