[llvm] [CodeGen][TII] Allow reassociation on custom operand indices (PR #88306)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 23 10:45:30 PDT 2024
https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/88306
>From d81a535c3aed080a9ca96a59f8d7dce2e436434d Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 11:41:27 -0700
Subject: [PATCH 1/2] [CodeGen][TII] Allow reassociation on custom operand
indices
This opens the door to reusing the reassociation optimizations on target-specific
binary operations with non-standard operand lists.
This is effectively an NFC (no functional change).
---
llvm/include/llvm/CodeGen/TargetInstrInfo.h | 10 ++
llvm/lib/CodeGen/TargetInstrInfo.cpp | 142 ++++++++++++++------
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 8 +-
3 files changed, 113 insertions(+), 47 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index d4a83e3753d980..3393d09b77e7ff 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -31,6 +31,7 @@
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/ErrorHandling.h"
+#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -1271,11 +1272,20 @@ class TargetInstrInfo : public MCInstrInfo {
return true;
}
+ /// The returned array encodes the operand index for each parameter because
+ /// the operands may be commuted; the operand indices for associative
+ /// operations might also be target-specific. Each element specifies the index
+ /// of {Prev, A, B, X, Y}.
+ virtual void
+ getReassociateOperandIdx(const MachineInstr &Root, unsigned Pattern,
+ std::array<unsigned, 5> &OperandIndices) const;
+
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
/// reduce critical path length.
void reassociateOps(MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
/// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 14b2e4268eb0fb..2a90d2f627b500 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1055,12 +1055,34 @@ static std::pair<bool, bool> mustSwapOperands(unsigned Pattern) {
}
}
+void TargetInstrInfo::getReassociateOperandIdx(
+ const MachineInstr &Root, unsigned Pattern,
+ std::array<unsigned, 5> &OperandIndices) const {
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ OperandIndices = {1, 1, 1, 2, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ OperandIndices = {2, 1, 2, 2, 1};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ OperandIndices = {1, 2, 1, 1, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ OperandIndices = {2, 2, 2, 1, 1};
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+}
+
/// Attempt the reassociation transformation to reduce critical path length.
/// See the above comments before getMachineCombinerPatterns().
void TargetInstrInfo::reassociateOps(
MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
MachineFunction *MF = Root.getMF();
MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1068,29 +1090,10 @@ void TargetInstrInfo::reassociateOps(
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
- // This array encodes the operand index for each parameter because the
- // operands may be commuted. Each row corresponds to a pattern value,
- // and each column specifies the index of A, B, X, Y.
- unsigned OpIdx[4][4] = {
- { 1, 1, 2, 2 },
- { 1, 2, 2, 1 },
- { 2, 1, 1, 2 },
- { 2, 2, 1, 1 }
- };
-
- int Row;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
- case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
- case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
- case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
- default: llvm_unreachable("unexpected MachineCombinerPattern");
- }
-
- MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
- MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
- MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
- MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+ MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+ MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+ MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+ MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
MachineOperand &OpC = Root.getOperand(0);
Register RegA = OpA.getReg();
@@ -1129,11 +1132,62 @@ void TargetInstrInfo::reassociateOps(
std::swap(KillX, KillY);
}
+ unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+ unsigned RootFirstOpIdx, RootSecondOpIdx;
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+
+ // Basically BuildMI but doesn't add implicit operands by default.
+ auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+ const MCInstrDesc &MCID, Register DestReg) {
+ return MachineInstrBuilder(
+ MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+ .setPCSections(MIMD.getPCSections())
+ .addReg(DestReg, RegState::Define);
+ };
+
// Create new instructions for insertion.
MachineInstrBuilder MIB1 =
- BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
- .addReg(RegX, getKillRegState(KillX))
- .addReg(RegY, getKillRegState(KillY));
+ buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+ for (const auto &MO : Prev.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+ // Skip the result operand we'd already added.
+ if (Idx == 0)
+ continue;
+ if (Idx == PrevFirstOpIdx)
+ MIB1.addReg(RegX, getKillRegState(KillX));
+ else if (Idx == PrevSecondOpIdx)
+ MIB1.addReg(RegY, getKillRegState(KillY));
+ else
+ MIB1.add(MO);
+ }
+ MIB1.copyImplicitOps(Prev);
if (SwapRootOperands) {
std::swap(RegA, NewVR);
@@ -1141,9 +1195,20 @@ void TargetInstrInfo::reassociateOps(
}
MachineInstrBuilder MIB2 =
- BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
- .addReg(RegA, getKillRegState(KillA))
- .addReg(NewVR, getKillRegState(KillNewVR));
+ buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+ for (const auto &MO : Root.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+ // Skip the result operand.
+ if (Idx == 0)
+ continue;
+ if (Idx == RootFirstOpIdx)
+ MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+ else if (Idx == RootSecondOpIdx)
+ MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+ else
+ MIB2 = MIB2.add(MO);
+ }
+ MIB2.copyImplicitOps(Root);
// Propagate FP flags from the original instructions.
// But clear poison-generating flags because those may not be valid now.
@@ -1187,25 +1252,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
// Select the previous instruction in the sequence based on the input pattern.
- MachineInstr *Prev = nullptr;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY:
- case MachineCombinerPattern::REASSOC_XA_BY:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
- break;
- case MachineCombinerPattern::REASSOC_AX_YB:
- case MachineCombinerPattern::REASSOC_XA_YB:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
- break;
- default:
- llvm_unreachable("Unknown pattern for machine combiner");
- }
+ std::array<unsigned, 5> OpIdx;
+ getReassociateOperandIdx(Root, Pattern, OpIdx);
+ MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
// Don't reassociate if Prev and Root are in different blocks.
if (Prev->getParent() != Root.getParent())
return;
- reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+ reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+ InstIdxForVirtReg);
}
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 8331fc0b8c3024..70ac1f8a592e02 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1582,10 +1582,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
MachineFunction &MF = *Root.getMF();
for (auto *NewMI : InsInstrs) {
- assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
- NewMI->getOpcode(), RISCV::OpName::frm)) ==
- NewMI->getNumOperands() &&
- "Instruction has unexpected number of operands");
+ // We'd already added the FRM operand.
+ if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+ NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+ continue;
MachineInstrBuilder MIB(MF, NewMI);
MIB.add(FRM);
if (FRM.getImm() == RISCVFPRndMode::DYN)
>From e71afbe4641a677e08c5a4fa1ffbee326d08f5e9 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 10 Apr 2024 16:38:38 -0700
Subject: [PATCH 2/2] fixup! [CodeGen][TII] Allow reassociation on custom
operand indices
---
llvm/include/llvm/CodeGen/TargetInstrInfo.h | 4 ++--
llvm/lib/CodeGen/TargetInstrInfo.cpp | 11 ++++++-----
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 3393d09b77e7ff..d5b1df2114e9e7 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -1277,8 +1277,8 @@ class TargetInstrInfo : public MCInstrInfo {
/// operations might also be target-specific. Each element specifies the index
/// of {Prev, A, B, X, Y}.
virtual void
- getReassociateOperandIdx(const MachineInstr &Root, unsigned Pattern,
- std::array<unsigned, 5> &OperandIndices) const;
+ getReassociateOperandIndices(const MachineInstr &Root, unsigned Pattern,
+ std::array<unsigned, 5> &OperandIndices) const;
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
/// reduce critical path length.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 2a90d2f627b500..e01e7b3888915b 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1055,7 +1055,7 @@ static std::pair<bool, bool> mustSwapOperands(unsigned Pattern) {
}
}
-void TargetInstrInfo::getReassociateOperandIdx(
+void TargetInstrInfo::getReassociateOperandIndices(
const MachineInstr &Root, unsigned Pattern,
std::array<unsigned, 5> &OperandIndices) const {
switch (Pattern) {
@@ -1252,15 +1252,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
// Select the previous instruction in the sequence based on the input pattern.
- std::array<unsigned, 5> OpIdx;
- getReassociateOperandIdx(Root, Pattern, OpIdx);
- MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
+ std::array<unsigned, 5> OperandIndices;
+ getReassociateOperandIndices(Root, Pattern, OperandIndices);
+ MachineInstr *Prev =
+ MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
// Don't reassociate if Prev and Root are in different blocks.
if (Prev->getParent() != Root.getParent())
return;
- reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+ reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
InstIdxForVirtReg);
}
More information about the llvm-commits
mailing list