[llvm] [AArch64][SME] Enable dynamic shuffle for fixed length types. (PR #72490)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 9 01:11:43 PST 2024


================
@@ -26171,13 +26196,39 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
                     DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
                     Op1, SVEMask);
-  else if (Subtarget.hasSVE2())
-    Shuffle =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
-                    DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
-                    Op1, Op2, SVEMask);
-  else
-    llvm_unreachable("Cannot lower shuffle without SVE2 TBL");
+  else if (Subtarget.hasSVE2()) {
+    if (!MinMaxEqual) {
+      SDValue VScale = (BitsPerElt == 64)
+                           ? DAG.getVScale(DL, MVT::i64, APInt(64, 1))
+                           : DAG.getVScale(DL, MVT::i32, APInt(32, 1));
+      SDValue Mul =
+          DAG.getNode(ISD::MUL, DL, (BitsPerElt == 64) ? MVT::i64 : MVT::i32,
+                      DAG.getConstant(128 / BitsPerElt, DL,
+                                      (BitsPerElt == 64) ? MVT::i64 : MVT::i32),
+                      VScale);
+      SDValue VecMask =
+          DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
+      SDValue MulMask = DAG.getBuildVector(
+          MaskType, DL, ArrayRef(MaskNormalized.data(), IndexLen));
+      SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, Mul);
+      SDValue MulMaskNormalized =
+          DAG.getNode(ISD::MUL, DL, MaskType, SplatPred, MulMask);
----------------
sdesmalen-arm wrote:

nit: There are a lot of variables here with only a single use, which makes this a bit tricky to read. Maybe you can fold some of these single-use values directly into the expressions that use them, e.g.:
```
SDValue MulMaskNormalized = DAG.getNode(
    ISD::MUL, DL, MaskType,
    DAG.getBuildVector(MaskType, DL,
                       ArrayRef(MaskNormalized.data(), IndexLen)),
    DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, Mul));
```

https://github.com/llvm/llvm-project/pull/72490


More information about the llvm-commits mailing list