[llvm] [ModuloSchedule] Implement modulo variable expansion for pipelining (PR #65609)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 5 01:19:14 PDT 2024


================
@@ -9596,31 +9628,256 @@ class AArch64PipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
     return {};
   }
 
+  void createRemainingIterationsGreaterCondition(
+      int TC, MachineBasicBlock &MBB, SmallVectorImpl<MachineOperand> &Cond,
+      DenseMap<MachineInstr *, MachineInstr *> &LastStage0Insts) override;
+
   void setPreheader(MachineBasicBlock *NewPreheader) override {}
 
   void adjustTripCount(int TripCountAdjust) override {}
 
   void disposed() override {}
+  bool isMVEExpanderSupported() override { return true; }
 };
 } // namespace
 
-static bool isCompareAndBranch(unsigned Opcode) {
-  switch (Opcode) {
-  case AArch64::CBZW:
-  case AArch64::CBZX:
-  case AArch64::CBNZW:
-  case AArch64::CBNZX:
-  case AArch64::TBZW:
-  case AArch64::TBZX:
-  case AArch64::TBNZW:
-  case AArch64::TBNZX:
-    return true;
+/// Clone MI and insert the clone into MBB at InsertTo. In the clone, operand 0
+/// gets a newly created virtual register (when it held a virtual register) and
+/// operand ReplaceOprNum gets ReplaceReg; all other operands are unchanged.
+static Register cloneInstr(const MachineInstr *MI, unsigned ReplaceOprNum,
+                           Register ReplaceReg, MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator InsertTo) {
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const TargetInstrInfo *TII = MBB.getParent()->getSubtarget().getInstrInfo();
+  const TargetRegisterInfo *TRI =
+      MBB.getParent()->getSubtarget().getRegisterInfo();
+  MachineInstr *NewMI = MBB.getParent()->CloneMachineInstr(MI);
+  Register Result = 0; // returned as-is (0) unless operand 0 is virtual
+  for (unsigned I = 0; I < NewMI->getNumOperands(); ++I) {
+    if (I == 0 && NewMI->getOperand(0).getReg().isVirtual()) {
+      Result = MRI.createVirtualRegister(
+          MRI.getRegClass(NewMI->getOperand(0).getReg()));
+      NewMI->getOperand(I).setReg(Result);
+    } else if (I == ReplaceOprNum) {
+      MRI.constrainRegClass(
+          ReplaceReg,
+          TII->getRegClass(NewMI->getDesc(), I, TRI, *MBB.getParent()));
+      NewMI->getOperand(I).setReg(ReplaceReg);
+    }
   }
-  return false;
+  MBB.insert(InsertTo, NewMI);
+  return Result;
+}
+
+void AArch64PipelinerLoopInfo::createRemainingIterationsGreaterCondition(
+    int TC, MachineBasicBlock &MBB, SmallVectorImpl<MachineOperand> &Cond,
+    DenseMap<MachineInstr *, MachineInstr *> &LastStage0Insts) {
+  // Create and accumulate the loop-exit conditions for the next TC iterations.
+  // Example:
+  //   SUBSXrr n, counter, implicit-def $nzcv # compare instruction for the last
+  //                                          # iteration of the kernel
+  //
+  //   # insert the following instructions
+  //   cond = CSINCXr 0, 0, C, implicit $nzcv
+  //   counter = ADDXri counter, 1            # clone from this->Update
+  //   SUBSXrr n, counter, implicit-def $nzcv # clone from this->Comp
+  //   cond = CSINCXr cond, cond, C, implicit $nzcv
+  //   ... (repeat TC times)
+  //   SUBSXri cond, 0, implicit-def $nzcv
+
+  assert(CondBranch->getOpcode() == AArch64::Bcc);
+  // Condition code under which the loop exits
+  AArch64CC::CondCode CC =
+      (AArch64CC::CondCode)CondBranch->getOperand(0).getImm();
+  if (CondBranch->getOperand(1).getMBB() == LoopBB)
+    CC = AArch64CC::getInvertedCondCode(CC);
+
+  // Accumulated loop-exit conditions
+  Register AccCond = AArch64::XZR;
+
+  // If CC holds, CurCond+1 is returned; otherwise CurCond is returned.
+  auto AccumulateCond = [&](Register CurCond,
+                            AArch64CC::CondCode CC) -> Register {
+    Register NewCond = MRI.createVirtualRegister(&AArch64::GPR64commonRegClass);
+    BuildMI(MBB, MBB.end(), Comp->getDebugLoc(), TII->get(AArch64::CSINCXr))
+        .addReg(NewCond, RegState::Define)
+        .addReg(CurCond)
+        .addReg(CurCond)
+        .addImm(AArch64CC::getInvertedCondCode(CC));
+    return NewCond;
+  };
+
+  if (!LastStage0Insts.empty() && LastStage0Insts[Comp]->getParent() == &MBB) {
+    // Update and Comp for I==0 already exist in MBB
+    // (MBB is an unrolled kernel)
+    Register Counter;
+    for (int I = 0; I <= TC; ++I) {
+      Register NextCounter;
+      if (I != 0)
+        NextCounter =
+            cloneInstr(Comp, CompCounterOprNum, Counter, MBB, MBB.end());
+
+      AccCond = AccumulateCond(AccCond, CC);
+
+      if (I != TC) {
+        if (I == 0) {
+          if (Update != Comp && IsUpdatePriorComp) {
+            Counter =
+                LastStage0Insts[Comp]->getOperand(CompCounterOprNum).getReg();
+            NextCounter = cloneInstr(Update, UpdateCounterOprNum, Counter, MBB,
+                                     MBB.end());
+          } else {
+            // Can reuse the already calculated value
+            NextCounter = LastStage0Insts[Update]->getOperand(0).getReg();
+          }
+        } else if (Update != Comp) {
+          NextCounter =
+              cloneInstr(Update, UpdateCounterOprNum, Counter, MBB, MBB.end());
+        }
+      }
+      Counter = NextCounter;
+    }
+  } else {
+    Register Counter;
+    if (LastStage0Insts.empty()) {
+      // Use the initial counter value (testing if the trip count is sufficient
+      // to be executed by the pipelined code)
+      Counter = Init;
+      if (IsUpdatePriorComp)
+        Counter =
+            cloneInstr(Update, UpdateCounterOprNum, Counter, MBB, MBB.end());
+    } else {
+      // MBB is an epilogue block. LastStage0Insts[Comp] is in the kernel block.
+      Counter = LastStage0Insts[Comp]->getOperand(CompCounterOprNum).getReg();
+    }
+
+    for (int I = 0; I <= TC; ++I) {
+      Register NextCounter;
+      NextCounter =
+          cloneInstr(Comp, CompCounterOprNum, Counter, MBB, MBB.end());
+      AccCond = AccumulateCond(AccCond, CC);
+      if (I != TC && Update != Comp)
+        NextCounter =
+            cloneInstr(Update, UpdateCounterOprNum, Counter, MBB, MBB.end());
+      Counter = NextCounter;
+    }
+  }
+
+  // If AccCond == 0, the remainder is greater than TC.
+  BuildMI(MBB, MBB.end(), Comp->getDebugLoc(), TII->get(AArch64::SUBSXri))
+      .addReg(AArch64::XZR, RegState::Define | RegState::Dead)
+      .addReg(AccCond)
+      .addImm(0)
+      .addImm(0);
+  Cond.clear(); // Cond is replaced by a single EQ condition-code operand
+  Cond.push_back(MachineOperand::CreateImm(AArch64CC::EQ));
+}
+
+static void extractPhiReg(const MachineInstr &Phi, const MachineBasicBlock *MBB,
+                          Register *RegMBB, Register *RegOther) {
+  assert(Phi.getNumOperands() == 5); // operands 1-4 are read as two (reg, MBB) pairs
+  if (Phi.getOperand(2).getMBB() == MBB) { // first pair is the one coming from MBB
+    *RegMBB = Phi.getOperand(1).getReg();
+    *RegOther = Phi.getOperand(3).getReg();
+  } else { // otherwise the second pair must be the one coming from MBB
+    assert(Phi.getOperand(4).getMBB() == MBB);
+    *RegMBB = Phi.getOperand(3).getReg();
+    *RegOther = Phi.getOperand(1).getReg();
+  }
+}
+
+static bool isDefinedOutside(Register Reg, const MachineBasicBlock *BB) {
+  if (!Reg.isVirtual())
+    return false; // only virtual registers are considered
+  const MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
+  return MRI.getVRegDef(Reg)->getParent() != BB; // NOTE(review): assumes getVRegDef(Reg) is non-null - confirm
+}
+
+/// If Reg is an induction variable, return true and set some parameters
+static bool getIndVarInfo(Register Reg, const MachineBasicBlock *LoopBB,
+                          MachineInstr *&UpdateInst,
+                          unsigned *UpdateCounterOprNum, Register *InitReg,
----------------
davemgreen wrote:

Could `UpdateCounterOprNum` and the other "output" parameters be passed by reference rather than by pointer, e.g. `unsigned &UpdateCounterOprNum`?

https://github.com/llvm/llvm-project/pull/65609


More information about the llvm-commits mailing list