[llvm] [AArch64][SME2] Add FORM_STRIDED_TUPLE pseudo nodes (PR #116399)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 29 13:45:30 PST 2024


================
@@ -8666,6 +8671,77 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
     }
   }
 
+  if (MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ||
+      MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO) {
+    MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    bool UseFormStrided = false;
+    unsigned Size =
+        MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ? 2 : 4;
+
+    // The FORM_STRIDED_TUPLE pseudo should only be used if the input operands
+    // are copy nodes where the source register is in a StridedOrContiguous
+    // class. For example:
+    //   %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
+    //   %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
+    //   %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
+    //   %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
+    //   %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
+    //   %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
+    //   %9:zpr2mul2 = FORM_STRIDED_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
+
+    SmallVector<unsigned, 4> OpSubRegs;
+    for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
+      MachineOperand &MO = MI.getOperand(I);
+      if (!MO.isReg())
+        continue;
+
+      MachineOperand *Def = MRI.getOneDef(MO.getReg());
+      if (!Def || !Def->isReg() || !Def->getParent()->isCopy())
+        continue;
+
+      MachineInstr *Cpy = Def->getParent();
+      MachineOperand CpyOp = Cpy->getOperand(1);
+      if (!CpyOp.isReg())
+        continue;
+
+      MachineOperand *Ld = MRI.getOneDef(CpyOp.getReg());
+      OpSubRegs.push_back(CpyOp.getSubReg());
+      if (!Ld || !Ld->isReg())
+        continue;
+
+      const TargetRegisterClass *RegClass =
+          Size == 2 ? &AArch64::ZPR2StridedOrContiguousRegClass
+                    : &AArch64::ZPR4StridedOrContiguousRegClass;
+
+      if (MRI.getRegClass(Ld->getReg()) == RegClass)
+        UseFormStrided = true;
----------------
sdesmalen-arm wrote:

It would help to move this loop into a separate function that returns `false` as soon as any operand fails to meet the criteria. You could then check the subreg indices in the same loop, instead of needing a separate `std::equal` afterwards.

https://github.com/llvm/llvm-project/pull/116399


More information about the llvm-commits mailing list