[llvm] [RISCV] Use DenseMap to track V0 definition. NFC (PR #84465)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 01:49:06 PDT 2024


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/84465

>From 0ed2d13e5b937a50818ecb2a30431aee2cac5fc8 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 7 Mar 2024 23:28:05 +0800
Subject: [PATCH 1/3] [RISCV] Use DenseMap to track V0 definition. NFC

Reviving some of the progress on #71764. To recap, we explored removing the
V0 register copies to simplify the pass, but hit a limitation with the
register allocator due to our use of the vmv0 singleton reg class and
early-clobber constraints.

So since we will have to continue to track the definition of V0 ourselves,
this patch simplifies it by storing it in a map. It will allow us to move
around copies to V0 in #71764 without having to do extra bookkeeping.
---
 llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 51 ++++++++++++++----------
 1 file changed, 30 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index fddbaa97d0638c..7baf6477ddf1d5 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
   StringRef getPassName() const override { return "RISC-V Fold Masks"; }
 
 private:
-  bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef) const;
-  bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef) const;
+  bool convertToUnmasked(MachineInstr &MI) const;
+  bool convertVMergeToVMv(MachineInstr &MI) const;
 
-  bool isAllOnesMask(MachineInstr *MaskDef) const;
+  bool isAllOnesMask(const MachineInstr *MaskDef) const;
+
+  /// Maps uses of V0 to the corresponding def of V0.
+  DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
 };
 
 } // namespace
@@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;
 
 INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
 
-bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
-  if (!MaskDef)
-    return false;
-  assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
+bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
+  assert(MaskDef && MaskDef->isCopy() &&
+         MaskDef->getOperand(0).getReg() == RISCV::V0);
   Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
   if (!SrcReg.isVirtual())
     return false;
@@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
 
 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
-bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
-                                        MachineInstr *V0Def) const {
+bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
 #define CASE_VMERGE_TO_VMV(lmul)                                               \
   case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
     NewOpc = RISCV::PseudoVMV_V_V_##lmul;                                      \
@@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
     return false;
 
   assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
-  if (!isAllOnesMask(V0Def))
+  if (!isAllOnesMask(V0Defs.lookup(&MI)))
     return false;
 
   MI.setDesc(TII->get(NewOpc));
@@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
   return true;
 }
 
-bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
-                                       MachineInstr *MaskDef) const {
+bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
   const RISCV::RISCVMaskedPseudoInfo *I =
       RISCV::getMaskedPseudoInfo(MI.getOpcode());
   if (!I)
     return false;
 
-  if (!isAllOnesMask(MaskDef))
+  if (!isAllOnesMask(V0Defs.lookup(&MI)))
     return false;
 
   // There are two classes of pseudos in the table - compares and
@@ -198,20 +198,29 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
   // $v0:vr = COPY %mask:vr
   // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
   //
-  // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
-  // on each pseudo.
-  MachineInstr *CurrentV0Def;
-  for (MachineBasicBlock &MBB : MF) {
-    CurrentV0Def = nullptr;
-    for (MachineInstr &MI : MBB) {
-      Changed |= convertToUnmasked(MI, CurrentV0Def);
-      Changed |= convertVMergeToVMv(MI, CurrentV0Def);
+  // Because $v0 isn't in SSA, keep track of its definition at each use so we
+  // can check mask operands.
+  for (const MachineBasicBlock &MBB : MF) {
+    const MachineInstr *CurrentV0Def = nullptr;
+    for (const MachineInstr &MI : MBB) {
+      auto IsV0 = [](const auto &MO) {
+        return MO.isReg() && MO.getReg() == RISCV::V0;
+      };
+      if (any_of(MI.uses(), IsV0))
+        V0Defs[&MI] = CurrentV0Def;
 
       if (MI.definesRegister(RISCV::V0, TRI))
         CurrentV0Def = &MI;
     }
   }
 
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      Changed |= convertToUnmasked(MI);
+      Changed |= convertVMergeToVMv(MI);
+    }
+  }
+
   return Changed;
 }
 

>From cf7bc7aa311c1dcbd820a69b4c2ada768bb708b5 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 11 Mar 2024 16:42:20 +0800
Subject: [PATCH 2/3] Use readsRegister

---
 llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index 7baf6477ddf1d5..e0b0f282911fd1 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -203,11 +203,8 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
   for (const MachineBasicBlock &MBB : MF) {
     const MachineInstr *CurrentV0Def = nullptr;
     for (const MachineInstr &MI : MBB) {
-      auto IsV0 = [](const auto &MO) {
-        return MO.isReg() && MO.getReg() == RISCV::V0;
-      };
-      if (any_of(MI.uses(), IsV0))
-        V0Defs[&MI] = CurrentV0Def;
+      if (MI.readsRegister(RISCV::V0, TRI))
+	V0Defs[&MI] = CurrentV0Def;
 
       if (MI.definesRegister(RISCV::V0, TRI))
         CurrentV0Def = &MI;

>From 1eb1f13b4084bc1860ae22da5a918e0a05fbfa8c Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 11 Mar 2024 16:48:47 +0800
Subject: [PATCH 3/3] clang-format

---
 llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index e0b0f282911fd1..2089f5dda6fe52 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -204,7 +204,7 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
     const MachineInstr *CurrentV0Def = nullptr;
     for (const MachineInstr &MI : MBB) {
       if (MI.readsRegister(RISCV::V0, TRI))
-	V0Defs[&MI] = CurrentV0Def;
+        V0Defs[&MI] = CurrentV0Def;
 
       if (MI.definesRegister(RISCV::V0, TRI))
         CurrentV0Def = &MI;



More information about the llvm-commits mailing list