[llvm] [AArch64][GlobalISel] Perfect Shuffles (PR #106446)

Madhur Amilkanthwar via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 28 22:13:07 PDT 2024


================
@@ -504,6 +504,189 @@ void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+/// Match a G_SHUFFLE_VECTOR whose destination is a 4-element vector -- the
+/// only shape covered by the perfect-shuffle table.
+bool matchPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
+  const Register Dst = MI.getOperand(0).getReg();
+  return MRI.getType(Dst).getNumElements() == 4;
+}
+
+/// Expand one perfect-shuffle table entry into generic machine IR and return
+/// the register holding the shuffled value.
+///
+/// \p PFEntry packs the recipe for one 4-lane shuffle:
+///   - bits [29:26]: OpNum, the operation to emit (OP_COPY, OP_MOVLANE,
+///                   OP_VREV, OP_VDUPn, OP_VEXTn, zip/uzp/trn, ...)
+///   - bits [25:13]: LHSID, table index describing how to build operand 1
+///   - bits [12:0]:  RHSID, table index for operand 2 (or, for OP_MOVLANE,
+///                   the lane encoding -- see below)
+/// An ID encodes the four lanes of a shuffle mask as base-9 digits, most
+/// significant digit first; the digit value 8 stands for an undef lane.
+/// Hence ((0*9+1)*9+2)*9+3 = <0,1,2,3> is a plain copy of LHS and
+/// ((4*9+5)*9+6)*9+7 = <4,5,6,7> a plain copy of RHS.  The function recurses
+/// on LHSID/RHSID until it bottoms out at OP_COPY leaves.
+static Register GeneratePerfectShuffle(unsigned ID, Register V1, Register V2,
+                                       unsigned PFEntry, Register LHS,
+                                       Register RHS, MachineIRBuilder &MIB,
+                                       MachineRegisterInfo &MRI) {
+  unsigned OpNum = (PFEntry >> 26) & 0x0F;
+  unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
+  unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
+
+  if (OpNum == OP_COPY) {
+    // <0,1,2,3>: identity mask of LHS.  The leading base-9 digit is 0, so it
+    // is implicit in the constant below.
+    if (LHSID == (1 * 9 + 2) * 9 + 3)
+      return LHS;
+    // Otherwise it must be <4,5,6,7>: identity mask of RHS.
+    assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
+    return RHS;
+  }
+
+  if (OpNum == OP_MOVLANE) {
+    // Decompose a PerfectShuffle ID to get the Mask for lane Elt.
+    // The ID stores lane 0 in its most significant base-9 digit, so peel
+    // (3 - Elt) digits off the low end first; a digit of 8 means undef (-1).
+    auto getPFIDLane = [](unsigned ID, int Elt) -> int {
+      assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
+      Elt = 3 - Elt;
+      while (Elt > 0) {
+        ID /= 9;
+        Elt--;
+      }
+      return (ID % 9 == 8) ? -1 : ID % 9;
+    };
+
+    // For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
+    // get the lane to move from the PFID, which is always from the
+    // original vectors (V1 or V2).
+    Register OpLHS = GeneratePerfectShuffle(
+        LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, MIB, MRI);
+    LLT VT = MRI.getType(OpLHS);
+    assert(RHSID < 8 && "Expected a lane index for RHSID!");
+    unsigned ExtLane = 0;
+    Register Input;
+
+    // OP_MOVLANE are either D movs (if bit 0x4 is set) or S movs. D movs
+    // convert into a higher type.
+    if (RHSID & 0x4) {
+      // 64-bit (D) lane move: each D lane covers a pair of source lanes, so
+      // halve the mask entry at source lane 2*(RHSID & 1) to get the D-lane
+      // index.
+      int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
+      if (MaskElt == -1)
+        // The even lane is undef; recover the pair from the odd lane, whose
+        // mask value is one past the even lane's.
+        MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
+      assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
+      // D lanes 0..1 come from V1, 2..3 from V2.
+      ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
+      Input = MaskElt < 2 ? V1 : V2;
+      // View the vectors with double-width elements so the extract/insert
+      // below moves a whole lane pair at once.
+      // NOTE(review): for a 16-bit element type, VT can never equal v2s32,
+      // so the second half of these guards is always true here -- presumably
+      // kept to mirror the SDAG implementation; confirm intent.
+      if (VT.getScalarSizeInBits() == 16 && VT != LLT::fixed_vector(2, 32)) {
+        Input = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(2, 32)}, {Input})
+                    .getReg(0);
+        OpLHS = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(2, 32)}, {OpLHS})
+                    .getReg(0);
+      }
+      if (VT.getScalarSizeInBits() == 32 && VT != LLT::fixed_vector(2, 64)) {
+        Input = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(2, 64)}, {Input})
+                    .getReg(0);
+        OpLHS = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(2, 64)}, {OpLHS})
+                    .getReg(0);
+      }
+    } else {
+      // Single-lane move: the mask entry at lane RHSID names the source lane
+      // directly; lanes 0..3 come from V1, 4..7 from V2.
+      int MaskElt = getPFIDLane(ID, RHSID);
+      assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
+      ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
+      Input = MaskElt < 4 ? V1 : V2;
+      // Normalize 16-bit element vectors to v4s16 so both operands of the
+      // extract/insert pair agree on type.
+      if (VT.getScalarSizeInBits() == 16 && VT != LLT::fixed_vector(4, 16)) {
+        Input = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(4, 16)}, {Input})
+                    .getReg(0);
+        OpLHS = MIB.buildInstr(TargetOpcode::G_BITCAST,
+                               {LLT::fixed_vector(4, 16)}, {OpLHS})
+                    .getReg(0);
+      }
+    }
+    // Pull the source lane out of Input and insert it into OpLHS at the
+    // destination lane, which is the low two bits of RHSID.
+    auto Ext = MIB.buildExtractVectorElementConstant(
+        MRI.getType(Input).getElementType(), Input, ExtLane);
+    auto Ins = MIB.buildInsertVectorElementConstant(MRI.getType(Input), OpLHS,
+                                                    Ext, RHSID & 0x3);
+    // If the operands were bitcast to a wider element type above, cast the
+    // result back to the expected type.
+    if (MRI.getType(Ins.getReg(0)) != VT)
+      return MIB.buildInstr(TargetOpcode::G_BITCAST, {VT}, {Ins}).getReg(0);
+    return Ins.getReg(0);
+  }
+
+  // Binary (or unary) ops: materialize both sub-shuffles first.
+  Register OpLHS, OpRHS;
+  OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
+                                 RHS, MIB, MRI);
+  OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
+                                 RHS, MIB, MRI);
+  LLT VT = MRI.getType(OpLHS);
+
+  switch (OpNum) {
+  default:
+    llvm_unreachable("Unknown shuffle opcode!");
+  case OP_VREV: {
+    // VREV divides the vector in half and swaps within the half.
+    // The REV width is one step above the element size: 32-bit elements use
+    // REV64, 16-bit use REV32, 8-bit use REV16.
+    unsigned Opcode = VT.getScalarSizeInBits() == 32   ? AArch64::G_REV64
+                      : VT.getScalarSizeInBits() == 16 ? AArch64::G_REV32
+                                                       : AArch64::G_REV16;
+    return MIB.buildInstr(Opcode, {VT}, {OpLHS}).getReg(0);
+  }
+  case OP_VDUP0:
+  case OP_VDUP1:
+  case OP_VDUP2:
+  case OP_VDUP3: {
+    unsigned Opcode;
+    if (VT.getScalarSizeInBits() == 8)
+      Opcode = AArch64::G_DUPLANE8;
+    else if (VT.getScalarSizeInBits() == 16)
+      Opcode = AArch64::G_DUPLANE16;
+    else if (VT.getScalarSizeInBits() == 32)
+      Opcode = AArch64::G_DUPLANE32;
+    else if (VT.getScalarSizeInBits() == 64)
+      Opcode = AArch64::G_DUPLANE64;
+    else
+      llvm_unreachable("Invalid vector element type?");
+
+    // Widen a 64-bit source to 128 bits with an undef upper half; the
+    // G_DUPLANE* source lane index is taken from a full 128-bit register.
+    if (VT.getSizeInBits() == 64)
+      OpLHS = MIB.buildConcatVectors(
+                     VT.changeElementCount(VT.getElementCount() * 2),
+                     {OpLHS, MIB.buildUndef(VT).getReg(0)})
+                  .getReg(0);
+    // The duplicated lane number is encoded in the opcode itself.
+    Register Lane =
+        MIB.buildConstant(LLT::scalar(64), OpNum - OP_VDUP0).getReg(0);
+    return MIB.buildInstr(Opcode, {VT}, {OpLHS, Lane}).getReg(0);
+  }
+  case OP_VEXT1:
+  case OP_VEXT2:
+  case OP_VEXT3: {
+    // EXT's immediate is in bytes: number of lanes shifted (1..3, encoded in
+    // the opcode) times the element size in bytes.
+    unsigned Imm = (OpNum - OP_VEXT1 + 1) * VT.getScalarSizeInBits() / 8;
+    return MIB
+        .buildInstr(AArch64::G_EXT, {VT},
+                    {OpLHS, OpRHS, MIB.buildConstant(LLT::scalar(64), Imm)})
+        .getReg(0);
+  }
+  case OP_VUZPL:
+    return MIB.buildInstr(AArch64::G_UZP1, {VT}, {OpLHS, OpRHS}).getReg(0);
+  case OP_VUZPR:
+    return MIB.buildInstr(AArch64::G_UZP2, {VT}, {OpLHS, OpRHS}).getReg(0);
+  case OP_VZIPL:
+    return MIB.buildInstr(AArch64::G_ZIP1, {VT}, {OpLHS, OpRHS}).getReg(0);
+  case OP_VZIPR:
+    return MIB.buildInstr(AArch64::G_ZIP2, {VT}, {OpLHS, OpRHS}).getReg(0);
+  case OP_VTRNL:
+    return MIB.buildInstr(AArch64::G_TRN1, {VT}, {OpLHS, OpRHS}).getReg(0);
+  case OP_VTRNR:
+    return MIB.buildInstr(AArch64::G_TRN2, {VT}, {OpLHS, OpRHS}).getReg(0);
+  }
+}
+
+void applyPerfectShuffle(MachineInstr &MI, MachineRegisterInfo &MRI,
+                         MachineIRBuilder &Builder) {
+  Register Dst = MI.getOperand(0).getReg();
+  Register LHS = MI.getOperand(1).getReg();
+  Register RHS = MI.getOperand(2).getReg();
+  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
+  assert(ShuffleMask.size() == 4 && "Expected 4 element mask");
+
+  unsigned PFIndexes[4];
+  for (unsigned i = 0; i != 4; ++i) {
+    if (ShuffleMask[i] < 0)
+      PFIndexes[i] = 8;
+    else
+      PFIndexes[i] = ShuffleMask[i];
+  }
+
+  // Compute the index in the perfect shuffle table.
----------------
madhur13490 wrote:

A comment explaining the constants used in the multiplications (the base-9 lane encoding, e.g. `(1 * 9 + 2) * 9 + 3`) would be good.

https://github.com/llvm/llvm-project/pull/106446


More information about the llvm-commits mailing list