[llvm] [AArch64][GlobalISel] Combine Shuffles of G_CONCAT_VECTORS (PR #87489)

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 12 04:28:18 PDT 2024


================
@@ -303,6 +303,82 @@ void CombinerHelper::applyCombineConcatVectors(MachineInstr &MI,
   replaceRegWith(MRI, DstReg, NewDstReg);
 }
 
+bool CombinerHelper::matchCombineShuffleConcat(MachineInstr &MI,
+                                               SmallVector<Register> &Ops) {
+  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
+  auto ConcatMI1 =
+      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg()));
+  auto ConcatMI2 =
+      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg()));
+  if (!ConcatMI1 || !ConcatMI2)
+    return false;
+
+  // Check that the sources of the Concat instructions have the same type
+  if (MRI.getType(ConcatMI1->getSourceReg(0)) !=
+      MRI.getType(ConcatMI2->getSourceReg(0)))
+    return false;
+
+  LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1));
+  LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg());
+  unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
+  for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
+    // Check if the index takes a whole source register from G_CONCAT_VECTORS
+    // Assumes that all Sources of G_CONCAT_VECTORS are the same type
+    if (Mask[i] == -1) {
+      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
+        if (i + j >= Mask.size())
+          return false;
+        if (Mask[i + j] != -1)
+          return false;
+      }
+      Ops.push_back(0);
+    } else if (Mask[i] % ConcatSrcNumElt == 0) {
+      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
+        if (i + j >= Mask.size())
+          return false;
+        if (Mask[i + j] != Mask[i] + static_cast<int>(j))
+          return false;
+      }
+      // Retrieve the source register from its respective G_CONCAT_VECTORS
+      // instruction
+      if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
+        Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt));
+      } else {
+        Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt -
+                                              ConcatMI1->getNumSources()));
+      }
+    } else {
+      return false;
+    }
+  }
+
+  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
----------------
davemgreen wrote:

The G_IMPLICIT_DEF legality check can move up into the `if (Mask[i] == -1) {` case, so that we only check it when the mask actually needs an implicit def.

https://github.com/llvm/llvm-project/pull/87489


More information about the llvm-commits mailing list