[llvm] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (PR #112228)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 15 08:36:22 PDT 2024


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/112228

>From 6b5faa1cd1e92a6450654e120740385311824341 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 14 Oct 2024 08:36:42 -0700
Subject: [PATCH 1/4] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX

The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.

---------

Co-authored-by: Kito Cheng <kito.cheng at sifive.com>
---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 111 ++++++++++++++++++---
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll      |  51 ++++++----
 2 files changed, 125 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index eb1f4df4ff7264..2c87826ab7883c 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,6 +31,44 @@ using namespace llvm;
 
 namespace {
 
+struct VLInfo {
+  VLInfo(const MachineOperand &VLOp) {
+    IsImm = VLOp.isImm();
+    if (IsImm)
+      Imm = VLOp.getImm();
+    else
+      Reg = VLOp.getReg();
+  }
+
+  Register Reg;
+  int64_t Imm;
+  bool IsImm;
+
+  bool isCompatible(const MachineOperand &VLOp) const {
+    if (IsImm != VLOp.isImm())
+      return false;
+    if (IsImm)
+      return Imm == VLOp.getImm();
+    return Reg == VLOp.getReg();
+  }
+
+  bool isValid() const { return IsImm || Reg.isVirtual(); }
+
+  bool hasBenefit(const MachineOperand &VLOp) const {
+    if (IsImm && Imm == RISCV::VLMaxSentinel)
+      return false;
+
+    if (!IsImm || !VLOp.isImm())
+      return true;
+
+    if (VLOp.getImm() == RISCV::VLMaxSentinel)
+      return true;
+
+    // No benefit if the current VL is already smaller than the new one.
+    return Imm < VLOp.getImm();
+  }
+};
+
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -51,7 +89,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -643,8 +681,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
 
   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
   const MachineOperand &VLOp = MI.getOperand(VLOpNum);
-  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+  if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
+       VLOp.isReg())) {
+    bool UseTAPolicy = false;
+    bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+    if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
+      unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
+      const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
+      uint64_t Policy = PolicyOp.getImm();
+      UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+      if (HasPassthru) {
+        unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+        UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
+                                      RISCV::NoRegister);
+      }
+    }
+    if (!UseTAPolicy) {
+      LLVM_DEBUG(
+          dbgs() << "  Not a candidate because it uses tail-undisturbed policy"
+                    " with non-VLMAX VL\n");
+      return false;
+    }
+  }
+
+  // If the VL is 1, then there is no need to reduce it.
+  if (VLOp.isImm() && VLOp.getImm() == 1) {
+    LLVM_DEBUG(dbgs() << "  Not a candidate because VL is already 1\n");
     return false;
+  }
 
   // Some instructions that produce vectors have semantics that make it more
   // difficult to determine whether the VL can be reduced. For example, some
@@ -667,7 +731,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
                                   MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
@@ -721,8 +785,9 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
     }
 
     if (!CommonVL) {
-      CommonVL = VLOp.getReg();
-    } else if (*CommonVL != VLOp.getReg()) {
+      CommonVL = VLInfo(VLOp);
+      LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
+    } else if (!CommonVL->isCompatible(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -759,7 +824,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<Register> CommonVL;
+    std::optional<VLInfo> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -767,21 +832,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isVirtual()) {
-      LLVM_DEBUG(
-          dbgs() << "    Abort due to new VL is not virtual register.\n");
+    if (!CommonVL->isValid()) {
+      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
       continue;
     }
 
-    const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
-    if (!MDT->dominates(VLMI, &MI))
-      continue;
-
-    // All our checks passed. We can reduce VL.
-    LLVM_DEBUG(dbgs() << "    Reducing VL for: " << MI << "\n");
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
-    VLOp.ChangeToRegister(*CommonVL, false);
+
+    if (!CommonVL->hasBenefit(VLOp)) {
+      LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
+      continue;
+    }
+
+    if (CommonVL->IsImm) {
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << CommonVL->Imm << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->Imm);
+    } else {
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      if (!MDT->dominates(VLMI, &MI))
+        continue;
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
+                        << " for " << MI << "\n");
+
+      // All our checks passed. We can reduce VL.
+      VLOp.ChangeToRegister(CommonVL->Reg, false);
+    }
+
     MadeChange = true;
 
     // Now add all inputs to this instruction to the worklist.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index 0b3e67ec895566..e8ac4efc770484 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -23,7 +23,7 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
   ret <vscale x 4 x i32> %w
 }
 
-; No benificial to propagate VL since VL is larger in the use side.
+; Not beneficial to propagate VL since VL is larger in the use side.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
 ; CHECK:       # %bb.0:
@@ -38,20 +38,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_reg_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_reg_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
 }
 
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_1:
 ; CHECK:       # %bb.0:
@@ -69,13 +75,20 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
 ; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
 ; undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v10, v8, v10
-; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
+; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v10, v8, v10
+; VLOPT-NEXT:    vadd.vv v8, v10, v8
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w
@@ -110,7 +123,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
   ret <vscale x 4 x i32> %w
 }
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; NOVLOPT: {{.*}}
-; VLOPT: {{.*}}

>From 9ad694e590c8ffd1216b1fbae4f701c1d95a5971 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 07:08:51 -0700
Subject: [PATCH 2/4] fixup! use std::variant

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 60 +++++++++++++---------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 2c87826ab7883c..9b69aa0c60c4a8 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -32,40 +32,51 @@ using namespace llvm;
 namespace {
 
 struct VLInfo {
+  std::variant<Register, int64_t> VL;
+
   VLInfo(const MachineOperand &VLOp) {
-    IsImm = VLOp.isImm();
-    if (IsImm)
-      Imm = VLOp.getImm();
+    if (VLOp.isImm())
+      VL = VLOp.getImm();
     else
-      Reg = VLOp.getReg();
+      VL = VLOp.getReg();
   }
 
-  Register Reg;
-  int64_t Imm;
-  bool IsImm;
-
   bool isCompatible(const MachineOperand &VLOp) const {
-    if (IsImm != VLOp.isImm())
+    if (isImm() != VLOp.isImm())
       return false;
-    if (IsImm)
-      return Imm == VLOp.getImm();
-    return Reg == VLOp.getReg();
+    if (isImm())
+      return getImm() == VLOp.getImm();
+    return getReg() == VLOp.getReg();
+  }
+
+  bool isImm() const { return std::holds_alternative<int64_t>(VL); }
+
+  bool isReg() const { return std::holds_alternative<Register>(VL); }
+
+  bool isValid() const { return isImm() || getReg().isVirtual(); }
+
+  int64_t getImm() const {
+    assert (isImm() && "Expected VL to be an immediate");
+    return std::get<int64_t>(VL);
   }
 
-  bool isValid() const { return IsImm || Reg.isVirtual(); }
+  Register getReg() const {
+    assert (isReg() && "Expected VL to be a Register");
+    return std::get<Register>(VL);
+  }
 
   bool hasBenefit(const MachineOperand &VLOp) const {
-    if (IsImm && Imm == RISCV::VLMaxSentinel)
+    if (isImm() && getImm() == RISCV::VLMaxSentinel)
       return false;
 
-    if (!IsImm || !VLOp.isImm())
+    if (!isImm() || !VLOp.isImm())
       return true;
 
     if (VLOp.getImm() == RISCV::VLMaxSentinel)
       return true;
 
     // No benefit if the current VL is already smaller than the new one.
-    return Imm < VLOp.getImm();
+    return getImm() < VLOp.getImm();
   }
 };
 
@@ -845,20 +856,21 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
       continue;
     }
 
-    if (CommonVL->IsImm) {
+    if (CommonVL->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVL->Imm << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVL->Imm);
+                        << CommonVL->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
-      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
-                        << " for " << MI << "\n");
+      LLVM_DEBUG(
+          dbgs() << "  Reduce VL from " << VLOp << " to "
+                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+                 << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVL->Reg, false);
+      VLOp.ChangeToRegister(CommonVL->getReg(), false);
     }
 
     MadeChange = true;

>From bb7db6422c1599541d9bd14e95f2a3aea402c786 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 07:25:41 -0700
Subject: [PATCH 3/4] fixup! clang-format

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 9b69aa0c60c4a8..65d28f8bc7779b 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -56,12 +56,12 @@ struct VLInfo {
   bool isValid() const { return isImm() || getReg().isVirtual(); }
 
   int64_t getImm() const {
-    assert (isImm() && "Expected VL to be an immediate");
+    assert(isImm() && "Expected VL to be an immediate");
     return std::get<int64_t>(VL);
   }
 
   Register getReg() const {
-    assert (isReg() && "Expected VL to be a Register");
+    assert(isReg() && "Expected VL to be a Register");
     return std::get<Register>(VL);
   }
 

>From f679b7ab981c15a362652bdd144eccf76bb24955 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 08:15:28 -0700
Subject: [PATCH 4/4] fixup! remove VLInfo

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 14 ++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.h        |  3 +
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp    | 84 ++++---------------
 llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp | 22 +----
 llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll | 40 +++++----
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll         | 63 +++++++-------
 6 files changed, 92 insertions(+), 134 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index b8539a5d1add14..3989a966edfd33 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
   assert(Scaled >= 3 && Scaled <= 6);
   return Scaled;
 }
+
+/// Given two VL operands, do we know that LHS <= RHS?
+bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
+  if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
+      LHS.getReg() == RHS.getReg())
+    return true;
+  if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
+    return true;
+  if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
+    return false;
+  if (!LHS.isImm() || !RHS.isImm())
+    return false;
+  return LHS.getImm() <= RHS.getImm();
+}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 457db9b9860d00..c3aa367486627a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW);
 // Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
 static constexpr int64_t VLMaxSentinel = -1LL;
 
+/// Given two VL operands, do we know that LHS <= RHS?
+bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS);
+
 // Mask assignments for floating-point
 static constexpr unsigned FPMASK_Negative_Infinity = 0x001;
 static constexpr unsigned FPMASK_Negative_Normal = 0x002;
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 65d28f8bc7779b..2a7a59c8e03e1f 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -30,56 +30,6 @@ using namespace llvm;
 #define PASS_NAME "RISC-V VL Optimizer"
 
 namespace {
-
-struct VLInfo {
-  std::variant<Register, int64_t> VL;
-
-  VLInfo(const MachineOperand &VLOp) {
-    if (VLOp.isImm())
-      VL = VLOp.getImm();
-    else
-      VL = VLOp.getReg();
-  }
-
-  bool isCompatible(const MachineOperand &VLOp) const {
-    if (isImm() != VLOp.isImm())
-      return false;
-    if (isImm())
-      return getImm() == VLOp.getImm();
-    return getReg() == VLOp.getReg();
-  }
-
-  bool isImm() const { return std::holds_alternative<int64_t>(VL); }
-
-  bool isReg() const { return std::holds_alternative<Register>(VL); }
-
-  bool isValid() const { return isImm() || getReg().isVirtual(); }
-
-  int64_t getImm() const {
-    assert(isImm() && "Expected VL to be an immediate");
-    return std::get<int64_t>(VL);
-  }
-
-  Register getReg() const {
-    assert(isReg() && "Expected VL to be a Register");
-    return std::get<Register>(VL);
-  }
-
-  bool hasBenefit(const MachineOperand &VLOp) const {
-    if (isImm() && getImm() == RISCV::VLMaxSentinel)
-      return false;
-
-    if (!isImm() || !VLOp.isImm())
-      return true;
-
-    if (VLOp.getImm() == RISCV::VLMaxSentinel)
-      return true;
-
-    // No benefit if the current VL is already smaller than the new one.
-    return getImm() < VLOp.getImm();
-  }
-};
-
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -100,7 +50,8 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<const MachineOperand *> &CommonVL,
+                  MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -742,8 +693,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
-                                  MachineInstr &MI) {
+bool RISCVVLOptimizer::checkUsers(
+    std::optional<const MachineOperand *> &CommonVL, MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
   // along lines of an instcombine style worklist which integrates the outer
@@ -789,16 +740,16 @@ bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
     unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
     const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
     // Looking for a register VL that isn't X0.
-    if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) {
+    if (VLOp.isReg() && VLOp.getReg() == RISCV::X0) {
       LLVM_DEBUG(dbgs() << "    Abort due to user uses X0 as VL.\n");
       CanReduceVL = false;
       break;
     }
 
     if (!CommonVL) {
-      CommonVL = VLInfo(VLOp);
+      CommonVL = &VLOp;
       LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
-    } else if (!CommonVL->isCompatible(VLOp)) {
+    } else if (!(*CommonVL)->isIdenticalTo(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -835,7 +786,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<VLInfo> CommonVL;
+    std::optional<const MachineOperand *> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -843,34 +794,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isValid()) {
-      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
+    const MachineOperand *CommonVLMO = *CommonVL;
+    if (!CommonVLMO->isImm() && !CommonVLMO->getReg().isVirtual()) {
+      LLVM_DEBUG(dbgs() << "    Abort because common VL is not valid.\n");
       continue;
     }
 
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
 
-    if (!CommonVL->hasBenefit(VLOp)) {
+    if (!RISCV::isVLKnownLE(*CommonVLMO, VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
       continue;
     }
 
-    if (CommonVL->isImm()) {
+    if (CommonVLMO->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVL->getImm() << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVL->getImm());
+                        << CommonVLMO->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVLMO->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVLMO->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
       LLVM_DEBUG(
           dbgs() << "  Reduce VL from " << VLOp << " to "
-                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+                 << printReg(CommonVLMO->getReg(), MRI->getTargetRegisterInfo())
                  << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVL->getReg(), false);
+      VLOp.ChangeToRegister(CommonVLMO->getReg(), false);
     }
 
     MadeChange = true;
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b883c50beadc09..a57bc5a3007d03 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0;
 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
                 false)
 
-/// Given two VL operands, do we know that LHS <= RHS?
-static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
-  if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
-      LHS.getReg() == RHS.getReg())
-    return true;
-  if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
-    return true;
-  if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
-    return false;
-  if (!LHS.isImm() || !RHS.isImm())
-    return false;
-  return LHS.getImm() <= RHS.getImm();
-}
-
 /// Given \p User that has an input operand with EEW=SEW, which uses the dest
 /// operand of \p Src with an unknown EEW, return true if their EEWs match.
 bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
@@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
     return false;
 
   MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
-  if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
+  if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
     return false;
 
   if (!ensureDominates(VL, *Src))
@@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
     MachineOperand &SrcPolicy =
         Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));
 
-    if (isVLKnownLE(MIVL, SrcVL))
+    if (RISCV::isVLKnownLE(MIVL, SrcVL))
       SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
   }
 
@@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
   // so we don't need to handle a smaller source VL here.  However, the
   // user's VL may be larger
   MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
-  if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
+  if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
     return false;
 
   // If the new passthru doesn't dominate Src, try to move Src so it does.
@@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
   // If MI was tail agnostic and the VL didn't increase, preserve it.
   int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
   if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
-      isVLKnownLE(MI.getOperand(3), SrcVL))
+      RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
     Policy |= RISCVII::TAIL_AGNOSTIC;
   Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
index 6e604d200a6279..b036d43c3dd409 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
@@ -9,18 +9,31 @@
 ; RUN:   -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
 
 define <2 x i32> @vdot_lane_s32(<2 x i32> noundef %var_1, <8 x i8> noundef %var_3, <8 x i8> noundef %var_5, <8 x i16> %x) {
-; CHECK-LABEL: vdot_lane_s32:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v11, 0
-; CHECK-NEXT:    vnsrl.wi v9, v11, 16
-; CHECK-NEXT:    vwadd.vv v10, v8, v9
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v10, 0
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vnsrl.wx v9, v10, a0
-; CHECK-NEXT:    vadd.vv v8, v8, v9
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: vdot_lane_s32:
+; NOVLOPT:       # %bb.0: # %entry
+; NOVLOPT-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; NOVLOPT-NEXT:    vnsrl.wi v8, v11, 0
+; NOVLOPT-NEXT:    vnsrl.wi v9, v11, 16
+; NOVLOPT-NEXT:    vwadd.vv v10, v8, v9
+; NOVLOPT-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NOVLOPT-NEXT:    vnsrl.wi v8, v10, 0
+; NOVLOPT-NEXT:    li a0, 32
+; NOVLOPT-NEXT:    vnsrl.wx v9, v10, a0
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v9
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: vdot_lane_s32:
+; VLOPT:       # %bb.0: # %entry
+; VLOPT-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; VLOPT-NEXT:    vnsrl.wi v8, v11, 0
+; VLOPT-NEXT:    vnsrl.wi v9, v11, 16
+; VLOPT-NEXT:    vwadd.vv v10, v8, v9
+; VLOPT-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; VLOPT-NEXT:    vnsrl.wi v8, v10, 0
+; VLOPT-NEXT:    li a0, 32
+; VLOPT-NEXT:    vnsrl.wx v9, v10, a0
+; VLOPT-NEXT:    vadd.vv v8, v8, v9
+; VLOPT-NEXT:    ret
 entry:
   %a = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
   %b = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
@@ -88,6 +101,3 @@ entry:
   ret <vscale x 2 x i16> %x
 }
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; NOVLOPT: {{.*}}
-; VLOPT: {{.*}}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index e8ac4efc770484..7fb245bff5d5be 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -11,13 +11,20 @@
 declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
 
 define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 5, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 5, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
   ret <vscale x 4 x i32> %w
@@ -38,20 +45,13 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
-; NOVLOPT:       # %bb.0:
-; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
-; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
-; NOVLOPT-NEXT:    ret
-;
-; VLOPT-LABEL: different_imm_reg_vl_with_ta:
-; VLOPT:       # %bb.0:
-; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; VLOPT-NEXT:    vadd.vv v8, v10, v12
-; VLOPT-NEXT:    vadd.vv v8, v8, v10
-; VLOPT-NEXT:    ret
+; CHECK-LABEL: different_imm_reg_vl_with_ta:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
@@ -75,20 +75,13 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
 ; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
 ; undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; NOVLOPT-LABEL: different_vl_with_ta:
-; NOVLOPT:       # %bb.0:
-; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
-; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
-; NOVLOPT-NEXT:    ret
-;
-; VLOPT-LABEL: different_vl_with_ta:
-; VLOPT:       # %bb.0:
-; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; VLOPT-NEXT:    vadd.vv v10, v8, v10
-; VLOPT-NEXT:    vadd.vv v8, v10, v8
-; VLOPT-NEXT:    ret
+; CHECK-LABEL: different_vl_with_ta:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v10, v8, v10
+; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w



More information about the llvm-commits mailing list