[llvm] [AArch64] Add streaming-mode stack hazard optimization remarks (PR #101695)
Hari Limaye via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 5 05:54:40 PDT 2024
================
@@ -5029,3 +5050,171 @@ void AArch64FrameLowering::inlineStackProbe(MachineFunction &MF,
MI->eraseFromParent();
}
}
+
+/// Describes one stack object and the kinds of register that load/store
+/// instructions use to access it. Used to report objects (or neighbouring
+/// objects) that are touched by both CPU-side and FPR accesses.
+struct StackAccess {
+  /// Bit flags; several may be OR-ed together in \c AccessTypes.
+  enum AccessType {
+    NotAccessed = 0, // Stack object not accessed by load/store instructions.
+    GPR = 1 << 0,    // A general purpose register.
+    PPR = 1 << 1,    // A predicate register.
+    FPR = 1 << 2,    // A floating point/Neon/SVE register.
+  };
+
+  int Idx;              // Frame index of the stack object.
+  StackOffset Offset;   // Offset of the object (fixed + scalable parts).
+  int64_t Size;         // Size of the object.
+  unsigned AccessTypes; // Bitwise OR of AccessType flags seen so far.
+
+  StackAccess() : Idx(0), Offset(), Size(0), AccessTypes(NotAccessed) {}
+
+  /// Orders by start() first and by frame index second, so objects with the
+  /// same start still have a deterministic order.
+  bool operator<(const StackAccess &Rhs) const {
+    return std::make_tuple(start(), Idx) <
+           std::make_tuple(Rhs.start(), Rhs.Idx);
+  }
+
+  /// True if any access to the object uses a GPR or a PPR.
+  bool isCPU() const {
+    // Predicate register load and store instructions execute on the CPU.
+    return (AccessTypes & (AccessType::GPR | AccessType::PPR)) != 0;
+  }
+
+  /// True if any access to the object uses an FPR.
+  bool isSME() const { return (AccessTypes & AccessType::FPR) != 0; }
+
+  /// True only when the object has both a CPU-side (GPR/PPR) access and an
+  /// FPR access. A GPR|PPR combination is deliberately NOT mixed: as noted in
+  /// isCPU(), both execute on the CPU. Testing for "more than one bit set"
+  /// would wrongly flag that combination.
+  bool isMixed() const { return isCPU() && isSME(); }
+
+  /// Sum of the fixed and scalable components of the offset. The scalable
+  /// component is added as-is (it is not multiplied by vscale).
+  int64_t start() const { return Offset.getFixed() + Offset.getScalable(); }
+
+  /// start() plus the object size.
+  int64_t end() const { return start() + Size; }
+
+  /// Short name for the access kind. Any combination of more than one flag
+  /// (including GPR|PPR) falls through to "Mixed".
+  std::string getTypeString() const {
+    switch (AccessTypes) {
+    case AccessType::FPR:
+      return "FPR";
+    case AccessType::PPR:
+      return "PPR";
+    case AccessType::GPR:
+      return "GPR";
+    case AccessType::NotAccessed:
+      return "NA";
+    default:
+      return "Mixed";
+    }
+  }
+
+  /// Writes "<type> stack object at [SP<+/-fixed>[<+/-scalable> * vscale]]"
+  /// to \p OS. The scalable term is omitted when it is zero.
+  void print(raw_ostream &OS) const {
+    // Negative values already carry their own '-', so only add '+' otherwise.
+    OS << getTypeString() << " stack object at [SP"
+       << (Offset.getFixed() < 0 ? "" : "+") << Offset.getFixed();
+    if (Offset.getScalable())
+      OS << (Offset.getScalable() < 0 ? "" : "+") << Offset.getScalable()
+         << " * vscale";
+    OS << "]";
+  }
+};
+
+/// Stream-insertion overload for StackAccess: delegates to
+/// StackAccess::print() and returns \p OS so insertions can be chained.
+static inline raw_ostream &operator<<(raw_ostream &OS,
+                                      const StackAccess &SA) {
+  SA.print(OS);
+  return OS;
+}
+
+void AArch64FrameLowering::emitRemarks(
+ const MachineFunction &MF, MachineOptimizationRemarkEmitter &ORE) const {
+
+ SMEAttrs Attrs(MF.getFunction());
+ if (Attrs.hasNonStreamingInterfaceAndBody())
+ return;
+
+ const uint64_t HazardSize =
+ (StackHazardSize) ? StackHazardSize : StackHazardRemarkSize;
+
+ if (HazardSize == 0)
+ return;
+
+ const MachineFrameInfo &MFI = MF.getFrameInfo();
+
+ std::vector<StackAccess> StackAccesses(MFI.getNumObjects());
+
+ size_t NumFPLdSt = 0;
+ size_t NumNonFPLdSt = 0;
+
+ // Collect stack accesses via Load/Store instructions.
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ if (!MI.mayLoadOrStore() || MI.getNumMemOperands() < 1)
+ continue;
+ for (MachineMemOperand *MMO : MI.memoperands()) {
+ std::optional<int> FI = getMMOFrameID(MMO, MFI);
+ if (FI && !MFI.isDeadObjectIndex(*FI)) {
+ int FrameIdx = *FI;
+
+ size_t ArrIdx = FrameIdx + MFI.getNumFixedObjects();
+ if (StackAccesses[ArrIdx].AccessTypes == StackAccess::NotAccessed) {
+ StackAccesses[ArrIdx].Idx = FrameIdx;
+ StackAccesses[ArrIdx].Offset =
+ getFrameIndexReferenceFromSP(MF, FrameIdx);
+ StackAccesses[ArrIdx].Size = MFI.getObjectSize(FrameIdx);
+ }
+
+ unsigned RegTy = StackAccess::AccessType::GPR;
+ if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector) {
+ if (AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()))
+ RegTy = StackAccess::PPR;
+ else
+ RegTy = StackAccess::FPR;
+ } else if (AArch64InstrInfo::isFpOrNEON(MI)) {
+ RegTy = StackAccess::FPR;
+ }
+
+ StackAccesses[ArrIdx].AccessTypes |= RegTy;
+
+ if (RegTy == StackAccess::FPR)
+ ++NumFPLdSt;
+ else
+ ++NumNonFPLdSt;
+ }
+ }
+ }
+ }
+
+ if (NumFPLdSt == 0 || NumNonFPLdSt == 0)
+ return;
+
+ llvm::sort(StackAccesses);
+ StackAccesses.erase(llvm::remove_if(StackAccesses,
+ [](const StackAccess &S) {
+ return S.AccessTypes ==
+ StackAccess::NotAccessed;
+ }),
+ StackAccesses.end());
+
+ SmallVector<const StackAccess *> MixedObjects;
+ SmallVector<std::pair<const StackAccess *, const StackAccess *>> HazardPairs;
+
+ if (StackAccesses.front().isMixed())
+ MixedObjects.push_back(&StackAccesses.front());
+
+ for (auto It = StackAccesses.begin(), End = StackAccesses.end();
+ It != (End - 1); ++It) {
----------------
hazzlim wrote:
Good point - Done.
https://github.com/llvm/llvm-project/pull/101695
More information about the llvm-commits
mailing list