[llvm] [AArch64][GlobalISel] Perfect Shuffles (PR #106446)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 13 02:01:01 PDT 2025


================
@@ -14090,8 +13924,95 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
     unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
                             PFIndexes[2] * 9 + PFIndexes[3];
     unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
-    return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
-                                  dl);
+
+    auto BuildRev = [&DAG, &dl](SDValue OpLHS) {
+      EVT VT = OpLHS.getValueType();
+      unsigned Opcode = VT.getScalarSizeInBits() == 32   ? AArch64ISD::REV64
+                        : VT.getScalarSizeInBits() == 16 ? AArch64ISD::REV32
+                                                         : AArch64ISD::REV16;
+      return DAG.getNode(Opcode, dl, VT, OpLHS);
+    };
+    auto BuildDup = [&DAG, &dl](SDValue OpLHS, unsigned Lane) {
+      EVT VT = OpLHS.getValueType();
+      unsigned Opcode;
+      if (VT.getScalarSizeInBits() == 8)
+        Opcode = AArch64ISD::DUPLANE8;
+      else if (VT.getScalarSizeInBits() == 16)
+        Opcode = AArch64ISD::DUPLANE16;
+      else if (VT.getScalarSizeInBits() == 32)
+        Opcode = AArch64ISD::DUPLANE32;
+      else if (VT.getScalarSizeInBits() == 64)
+        Opcode = AArch64ISD::DUPLANE64;
+      else
+        llvm_unreachable("Invalid vector element type?");
+
+      if (VT.getSizeInBits() == 64)
+        OpLHS = WidenVector(OpLHS, DAG);
+      return DAG.getNode(Opcode, dl, VT, OpLHS,
+                         DAG.getConstant(Lane, dl, MVT::i64));
+    };
+    auto BuildExt = [&DAG, &dl](SDValue OpLHS, SDValue OpRHS, unsigned Imm) {
+      EVT VT = OpLHS.getValueType();
+      Imm = Imm * getExtFactor(OpLHS);
+      return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
+                         DAG.getConstant(Imm, dl, MVT::i32));
+    };
+    auto BuildZipLike = [&DAG, &dl](unsigned OpNum, SDValue OpLHS,
+                                    SDValue OpRHS) {
+      EVT VT = OpLHS.getValueType();
+      switch (OpNum) {
+      default:
+        llvm_unreachable("Unexpected perfect shuffle opcode\n");
+      case OP_VUZPL:
+        return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
+      case OP_VUZPR:
+        return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
+      case OP_VZIPL:
+        return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
+      case OP_VZIPR:
+        return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
+      case OP_VTRNL:
+        return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
+      case OP_VTRNR:
+        return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
+      }
+    };
+    auto BuildExtractInsert64 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
+                                            SDValue InsSrc, unsigned InsLane) {
+      EVT VT = InsSrc.getValueType();
+      if (VT.getScalarSizeInBits() == 16) {
+        ExtSrc = DAG.getBitcast(MVT::v2f32, ExtSrc);
+        InsSrc = DAG.getBitcast(MVT::v2f32, InsSrc);
+      } else if (VT.getScalarSizeInBits() == 32) {
+        ExtSrc = DAG.getBitcast(MVT::v2f64, ExtSrc);
+        InsSrc = DAG.getBitcast(MVT::v2f64, InsSrc);
+      }
+      SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+                                ExtSrc.getValueType().getVectorElementType(),
+                                ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
+      SDValue Ins =
+          DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
+                      Ext, DAG.getVectorIdxConstant(InsLane, dl));
+      return DAG.getBitcast(VT, Ins);
+    };
+    auto BuildExtractInsert32 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
+                                            SDValue InsSrc, unsigned InsLane) {
+      EVT VT = InsSrc.getValueType();
+      if (VT.getScalarSizeInBits() == 16) {
+        ExtSrc = DAG.getBitcast(MVT::v4f16, ExtSrc);
+        InsSrc = DAG.getBitcast(MVT::v4f16, InsSrc);
+      }
+      SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+                                ExtSrc.getValueType().getVectorElementType(),
+                                ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
+      SDValue Ins =
+          DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
+                      Ext, DAG.getVectorIdxConstant(InsLane, dl));
----------------
arsenm wrote:

You can query the vector index type once and reuse it, instead of repeating the query in each of these cases.

https://github.com/llvm/llvm-project/pull/106446


More information about the llvm-commits mailing list