[llvm] [CodeGen][TII] Allow reassociation on custom operand indices (PR #88306)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 11:31:52 PDT 2024


https://github.com/mshockwave created https://github.com/llvm/llvm-project/pull/88306

This opens the door to reusing reassociation optimizations on target-specific binary operations with non-standard operand lists.

This is effectively NFC.

>From 8e67ee297d4e050b6fc0ac7cc6d2c14d514c8997 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 11:41:27 -0700
Subject: [PATCH] [CodeGen][TII] Allow reassociation on custom operand indices

This opens the door to reusing reassociation optimizations on target-specific
binary operations with non-standard operand lists.

This is effectively NFC.
---
 llvm/include/llvm/CodeGen/TargetInstrInfo.h |  11 ++
 llvm/lib/CodeGen/TargetInstrInfo.cpp        | 145 ++++++++++++++------
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp    |   8 +-
 3 files changed, 115 insertions(+), 49 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 9fd0ebe6956fbe..82c952b227557d 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -30,6 +30,7 @@
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <array>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
     return true;
   }
 
+  /// Fills \p OperandIndices with the operand index of each parameter, since
+  /// the operands may be commuted and the operand indices for associative
+  /// operations might also be target-specific. The elements are, in order,
+  /// the indices of {Prev, A, B, X, Y}.
+  virtual void
+  getReassociateOperandIdx(const MachineInstr &Root,
+                           MachineCombinerPattern Pattern,
+                           std::array<unsigned, 5> &OperandIndices) const;
+
   /// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
   /// reduce critical path length.
   void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
                       MachineCombinerPattern Pattern,
                       SmallVectorImpl<MachineInstr *> &InsInstrs,
                       SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      ArrayRef<unsigned> OperandIndices,
                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
 
   /// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 9fbd516acea8e1..488922e3c1b720 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1051,13 +1051,34 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
   }
 }
 
+void TargetInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    OperandIndices = {1, 1, 1, 2, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    OperandIndices = {2, 1, 2, 2, 1};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    OperandIndices = {1, 2, 1, 1, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    OperandIndices = {2, 2, 2, 1, 1};
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+}
+
 /// Attempt the reassociation transformation to reduce critical path length.
 /// See the above comments before getMachineCombinerPatterns().
 void TargetInstrInfo::reassociateOps(
-    MachineInstr &Root, MachineInstr &Prev,
-    MachineCombinerPattern Pattern,
+    MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
     SmallVectorImpl<MachineInstr *> &DelInstrs,
+    ArrayRef<unsigned> OperandIndices,
     DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
   MachineFunction *MF = Root.getMF();
   MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1065,29 +1086,10 @@ void TargetInstrInfo::reassociateOps(
   const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
   const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
 
-  // This array encodes the operand index for each parameter because the
-  // operands may be commuted. Each row corresponds to a pattern value,
-  // and each column specifies the index of A, B, X, Y.
-  unsigned OpIdx[4][4] = {
-    { 1, 1, 2, 2 },
-    { 1, 2, 2, 1 },
-    { 2, 1, 1, 2 },
-    { 2, 2, 1, 1 }
-  };
-
-  int Row;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
-  case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
-  case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
-  case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
-  default: llvm_unreachable("unexpected MachineCombinerPattern");
-  }
-
-  MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
-  MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
-  MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
-  MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+  MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+  MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+  MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+  MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
   MachineOperand &OpC = Root.getOperand(0);
 
   Register RegA = OpA.getReg();
@@ -1126,11 +1128,62 @@ void TargetInstrInfo::reassociateOps(
     std::swap(KillX, KillY);
   }
 
+  unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+  unsigned RootFirstOpIdx, RootSecondOpIdx;
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+
+  // Basically BuildMI but doesn't add implicit operands by default.
+  auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+                              const MCInstrDesc &MCID, Register DestReg) {
+    return MachineInstrBuilder(
+               MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+        .setPCSections(MIMD.getPCSections())
+        .addReg(DestReg, RegState::Define);
+  };
+
   // Create new instructions for insertion.
   MachineInstrBuilder MIB1 =
-      BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
-          .addReg(RegX, getKillRegState(KillX))
-          .addReg(RegY, getKillRegState(KillY));
+      buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+  for (const auto &MO : Prev.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand we'd already added.
+    if (Idx == 0)
+      continue;
+    if (Idx == PrevFirstOpIdx)
+      MIB1.addReg(RegX, getKillRegState(KillX));
+    else if (Idx == PrevSecondOpIdx)
+      MIB1.addReg(RegY, getKillRegState(KillY));
+    else
+      MIB1.add(MO);
+  }
+  MIB1.copyImplicitOps(Prev);
 
   if (SwapRootOperands) {
     std::swap(RegA, NewVR);
@@ -1138,9 +1191,20 @@ void TargetInstrInfo::reassociateOps(
   }
 
   MachineInstrBuilder MIB2 =
-      BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
-          .addReg(RegA, getKillRegState(KillA))
-          .addReg(NewVR, getKillRegState(KillNewVR));
+      buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+  for (const auto &MO : Root.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand.
+    if (Idx == 0)
+      continue;
+    if (Idx == RootFirstOpIdx)
+      MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+    else if (Idx == RootSecondOpIdx)
+      MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+    else
+      MIB2 = MIB2.add(MO);
+  }
+  MIB2.copyImplicitOps(Root);
 
   // Propagate FP flags from the original instructions.
   // But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
   MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
 
   // Select the previous instruction in the sequence based on the input pattern.
-  MachineInstr *Prev = nullptr;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY:
-  case MachineCombinerPattern::REASSOC_XA_BY:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
-    break;
-  case MachineCombinerPattern::REASSOC_AX_YB:
-  case MachineCombinerPattern::REASSOC_XA_YB:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
-    break;
-  default:
-    llvm_unreachable("Unknown pattern for machine combiner");
-  }
+  std::array<unsigned, 5> OpIdx;
+  getReassociateOperandIdx(Root, Pattern, OpIdx);
+  MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
 
   // Don't reassociate if Prev and Root are in different blocks.
   if (Prev->getParent() != Root.getParent())
     return;
 
-  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+                 InstIdxForVirtReg);
 }
 
 MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6b75efe684d913..5eeb0d7c27cb98 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
   MachineFunction &MF = *Root.getMF();
 
   for (auto *NewMI : InsInstrs) {
-    assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
-               NewMI->getOpcode(), RISCV::OpName::frm)) ==
-               NewMI->getNumOperands() &&
-           "Instruction has unexpected number of operands");
+    // We'd already added the FRM operand.
+    if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+            NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+      continue;
     MachineInstrBuilder MIB(MF, NewMI);
     MIB.add(FRM);
     if (FRM.getImm() == RISCVFPRndMode::DYN)



More information about the llvm-commits mailing list