[llvm] [AArch64][GlobalISel] Perfect Shuffles (PR #106446)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 13 02:01:00 PDT 2025


================
@@ -524,6 +524,129 @@ void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+/// Match a G_SHUFFLE_VECTOR whose destination vector has exactly 4 elements.
+bool matchPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
+  return MRI.getType(MI.getOperand(0).getReg()).getNumElements() == 4;
+}
+
+void applyPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI,
+                         MachineIRBuilder &MIB) {
+  Register Dst = MI.getOperand(0).getReg();
+  Register LHS = MI.getOperand(1).getReg();
+  Register RHS = MI.getOperand(2).getReg();
+  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
+  assert(ShuffleMask.size() == 4 && "Expected 4 element mask");
+
+  unsigned PFIndexes[4];
+  for (unsigned i = 0; i != 4; ++i) {
+    if (ShuffleMask[i] < 0)
+      PFIndexes[i] = 8;
+    else
+      PFIndexes[i] = ShuffleMask[i];
+  }
+
+  // Compute the index in the perfect shuffle table.
+  unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
+                          PFIndexes[2] * 9 + PFIndexes[3];
+  unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
+
+  auto BuildRev = [&MIB, &MRI](Register OpLHS) {
+    LLT Ty = MRI.getType(OpLHS);
+    unsigned Opcode = Ty.getScalarSizeInBits() == 32   ? AArch64::G_REV64
+                      : Ty.getScalarSizeInBits() == 16 ? AArch64::G_REV32
+                                                       : AArch64::G_REV16;
+    return MIB.buildInstr(Opcode, {Ty}, {OpLHS}).getReg(0);
+  };
+  auto BuildDup = [&MIB, &MRI](Register OpLHS, unsigned Lane) {
+    LLT Ty = MRI.getType(OpLHS);
+    unsigned Opcode;
+    if (Ty.getScalarSizeInBits() == 8)
+      Opcode = AArch64::G_DUPLANE8;
+    else if (Ty.getScalarSizeInBits() == 16)
+      Opcode = AArch64::G_DUPLANE16;
+    else if (Ty.getScalarSizeInBits() == 32)
+      Opcode = AArch64::G_DUPLANE32;
+    else if (Ty.getScalarSizeInBits() == 64)
+      Opcode = AArch64::G_DUPLANE64;
+    else
+      llvm_unreachable("Invalid vector element type?");
+
+    if (Ty.getSizeInBits() == 64)
+      OpLHS = MIB.buildConcatVectors(
+                     Ty.changeElementCount(Ty.getElementCount() * 2),
+                     {OpLHS, MIB.buildUndef(Ty).getReg(0)})
+                  .getReg(0);
+    Register LaneR = MIB.buildConstant(LLT::scalar(64), Lane).getReg(0);
+    return MIB.buildInstr(Opcode, {Ty}, {OpLHS, LaneR}).getReg(0);
+  };
+  auto BuildExt = [&MIB, &MRI](Register OpLHS, Register OpRHS, unsigned Imm) {
+    LLT Ty = MRI.getType(OpLHS);
+    Imm = Imm * Ty.getScalarSizeInBits() / 8;
+    return MIB
+        .buildInstr(AArch64::G_EXT, {Ty},
+                    {OpLHS, OpRHS, MIB.buildConstant(LLT::scalar(64), Imm)})
+        .getReg(0);
+  };
+  auto BuildZipLike = [&MIB, &MRI](unsigned OpNum, Register OpLHS,
+                                   Register OpRHS) {
+    LLT Ty = MRI.getType(OpLHS);
+    switch (OpNum) {
+    default:
+      llvm_unreachable("Unexpected perfect shuffle opcode\n");
+    case OP_VUZPL:
+      return MIB.buildInstr(AArch64::G_UZP1, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    case OP_VUZPR:
+      return MIB.buildInstr(AArch64::G_UZP2, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    case OP_VZIPL:
+      return MIB.buildInstr(AArch64::G_ZIP1, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    case OP_VZIPR:
+      return MIB.buildInstr(AArch64::G_ZIP2, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    case OP_VTRNL:
+      return MIB.buildInstr(AArch64::G_TRN1, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    case OP_VTRNR:
+      return MIB.buildInstr(AArch64::G_TRN2, {Ty}, {OpLHS, OpRHS}).getReg(0);
+    }
+  };
+  auto BuildExtractInsert64 = [&MIB, &MRI](Register ExtSrc, unsigned ExtLane,
+                                           Register InsSrc, unsigned InsLane) {
+    LLT Ty = MRI.getType(InsSrc);
+    if (Ty.getScalarSizeInBits() == 16 && Ty != LLT::fixed_vector(2, 32)) {
+      ExtSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), ExtSrc).getReg(0);
+      InsSrc = MIB.buildBitcast(LLT::fixed_vector(2, 32), InsSrc).getReg(0);
+    } else if (Ty.getScalarSizeInBits() == 32 &&
+               Ty != LLT::fixed_vector(2, 64)) {
+      ExtSrc = MIB.buildBitcast(LLT::fixed_vector(2, 64), ExtSrc).getReg(0);
+      InsSrc = MIB.buildBitcast(LLT::fixed_vector(2, 64), InsSrc).getReg(0);
+    }
+    auto Ext = MIB.buildExtractVectorElementConstant(
+        MRI.getType(ExtSrc).getElementType(), ExtSrc, ExtLane);
+    auto Ins = MIB.buildInsertVectorElementConstant(MRI.getType(ExtSrc), InsSrc,
+                                                    Ext, InsLane);
----------------
arsenm wrote:

ditto 

https://github.com/llvm/llvm-project/pull/106446


More information about the llvm-commits mailing list