[llvm] [AArch64][GlobalISel] Perfect Shuffles (PR #106446)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 23 05:01:17 PDT 2025


================
@@ -524,6 +524,140 @@ void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+/// Match a 4-element G_SHUFFLE_VECTOR.
+bool matchPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
+  return MRI.getType(MI.getOperand(0).getReg()).getNumElements() == 4;
+}
+
+void applyPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI,
+                         MachineIRBuilder &MIB) {
+  Register Dst = MI.getOperand(0).getReg();
+  Register LHS = MI.getOperand(1).getReg();
+  Register RHS = MI.getOperand(2).getReg();
+  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
+  assert(ShuffleMask.size() == 4 && "Expected 4 element mask");
+
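+  // Encode the mask for the perfect shuffle table lookup; undef elements
+  // (negative mask values) are represented by the sentinel value 8.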
+  unsigned PFIndexes[4];
+  for (unsigned i = 0; i != 4; ++i) {
+    if (ShuffleMask[i] < 0)
+      PFIndexes[i] = 8;
+    else
+      PFIndexes[i] = ShuffleMask[i];
+  }
+
+  // Compute the index in the perfect shuffle table.
+  unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
+                          PFIndexes[2] * 9 + PFIndexes[3];
+  unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
+
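+  // OP_VREV swaps adjacent element pairs, so pick the REV variant whose
+  // chunk width is twice the element width.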
+  auto BuildRev = [&MIB, &MRI](Register OpLHS) {
+    LLT Ty = MRI.getType(OpLHS);
+    unsigned Opcode = Ty.getScalarSizeInBits() == 32   ? AArch64::G_REV64
+                      : Ty.getScalarSizeInBits() == 16 ? AArch64::G_REV32
+                                                       : AArch64::G_REV16;
+    return MIB.buildInstr(Opcode, {Ty}, {OpLHS}).getReg(0);
+  };
+  auto BuildDup = [&MIB, &MRI](Register OpLHS, unsigned Lane) {
+    LLT Ty = MRI.getType(OpLHS);
+    unsigned Opcode;
+    if (Ty.getScalarSizeInBits() == 8)
+      Opcode = AArch64::G_DUPLANE8;
+    else if (Ty.getScalarSizeInBits() == 16)
+      Opcode = AArch64::G_DUPLANE16;
+    else if (Ty.getScalarSizeInBits() == 32)
+      Opcode = AArch64::G_DUPLANE32;
+    else if (Ty.getScalarSizeInBits() == 64)
+      Opcode = AArch64::G_DUPLANE64;
+    else
+      llvm_unreachable("Invalid vector element type?");
+
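+    // G_DUPLANE takes a 128-bit source vector, so widen a 64-bit input by
+    // concatenating it with undef.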
+    if (Ty.getSizeInBits() == 64)
+      OpLHS = MIB.buildConcatVectors(
+                     Ty.changeElementCount(Ty.getElementCount() * 2),
+                     {OpLHS, MIB.buildUndef(Ty).getReg(0)})
+                  .getReg(0);
+    Register LaneR = MIB.buildConstant(LLT::scalar(64), Lane).getReg(0);
+    return MIB.buildInstr(Opcode, {Ty}, {OpLHS, LaneR}).getReg(0);
+  };
+  auto BuildExt = [&MIB, &MRI](Register OpLHS, Register OpRHS, unsigned Imm) {
+    LLT Ty = MRI.getType(OpLHS);
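+    // The G_EXT immediate is a byte offset, so scale the element index by
+    // the element size in bytes.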
+    Imm = Imm * Ty.getScalarSizeInBits() / 8;
+    return MIB
+        .buildInstr(AArch64::G_EXT, {Ty},
+                    {OpLHS, OpRHS, MIB.buildConstant(LLT::scalar(64), Imm)})
+        .getReg(0);
+  };
+  auto BuildZipLike = [&MIB, &MRI](unsigned OpNum, Register OpLHS,
+                                   Register OpRHS) {
+    LLT Ty = MRI.getType(OpLHS);
+    unsigned Opc = 0;
+    switch (OpNum) {
+    default:
+      llvm_unreachable("Unexpected perfect shuffle opcode");
+    case OP_VUZPL:
+      Opc = AArch64::G_UZP1;
+      break;
+    case OP_VUZPR:
+      Opc = AArch64::G_UZP2;
+      break;
+    case OP_VZIPL:
+      Opc = AArch64::G_ZIP1;
+      break;
+    case OP_VZIPR:
+      Opc = AArch64::G_ZIP2;
+      break;
+    case OP_VTRNL:
+      Opc = AArch64::G_TRN1;
+      break;
+    case OP_VTRNR:
+      Opc = AArch64::G_TRN2;
+      break;
+    }
+    return MIB.buildInstr(Opc, {Ty}, {OpLHS, OpRHS}).getReg(0);
+  };
+  auto BuildExtractInsert64 = [&MIB, &MRI](Register ExtSrc, unsigned ExtLane,
+                                           Register InsSrc, unsigned InsLane) {
+    LLT Ty = MRI.getType(InsSrc);
+    if (Ty.getScalarSizeInBits() == 16 && Ty != LLT::fixed_vector(2, 32)) {
+      ExtSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), ExtSrc).getReg(0);
+      InsSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), InsSrc).getReg(0);
+    } else if (Ty.getScalarSizeInBits() == 32 &&
+               Ty != LLT::fixed_vector(2, 64)) {
+      ExtSrc = MIB.buildBitcast(LLT::fixed_vector(2, 64), ExtSrc).getReg(0);
+      InsSrc = MIB.buildBitcast(LLT::fixed_vector(2, 64), InsSrc).getReg(0);
+    }
+    auto Ext = MIB.buildExtractVectorElement(
+        MRI.getType(ExtSrc).getElementType(), ExtSrc,
+        MIB.buildConstant(LLT::scalar(64), ExtLane));
+    auto Ins = MIB.buildInsertVectorElement(
+        MRI.getType(ExtSrc), InsSrc, Ext,
+        MIB.buildConstant(LLT::scalar(64), InsLane));
+    return MIB.buildBitcast(Ty, Ins).getReg(0);
+  };
+  auto BuildExtractInsert32 = [&MIB, &MRI](Register ExtSrc, unsigned ExtLane,
+                                           Register InsSrc, unsigned InsLane) {
+    LLT Ty = MRI.getType(InsSrc);
+    if (Ty.getScalarSizeInBits() == 16 && Ty != LLT::fixed_vector(4, 16)) {
+      ExtSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), ExtSrc).getReg(0);
+      InsSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), InsSrc).getReg(0);
+    }
+    auto Ext = MIB.buildExtractVectorElement(
+        MRI.getType(ExtSrc).getElementType(), ExtSrc,
----------------
arsenm wrote:

Avoid the repeated getType calls.
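
For example (just a sketch, assuming the truncated body continues like
the 64-bit variant above), the type could be queried once after the
bitcasts and reused:

    // Query the (possibly bitcast) source type once instead of calling
    // MRI.getType(ExtSrc) at each use.
    LLT ExtTy = MRI.getType(ExtSrc);
    auto Ext = MIB.buildExtractVectorElement(
        ExtTy.getElementType(), ExtSrc,
        MIB.buildConstant(LLT::scalar(64), ExtLane));
    auto Ins = MIB.buildInsertVectorElement(
        ExtTy, InsSrc, Ext, MIB.buildConstant(LLT::scalar(64), InsLane));
    return MIB.buildBitcast(Ty, Ins).getReg(0);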

https://github.com/llvm/llvm-project/pull/106446

