[llvm] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (PR #112228)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 16 11:33:32 PDT 2024


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/112228

>From 25b54afca9e00cae94d0312a0acc728f90a81c8b Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 14 Oct 2024 08:36:42 -0700
Subject: [PATCH 01/10] [RISCV][VLOPT] Allow propagation even when VL isn't
 VLMAX

The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.

---------

Co-authored-by: Kito Cheng <kito.cheng at sifive.com>
---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 111 ++++++++++++++++++---
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll      |  51 ++++++----
 2 files changed, 125 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 6053899987db9b..82972f07157124 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,6 +31,44 @@ using namespace llvm;
 
 namespace {
 
+struct VLInfo {
+  VLInfo(const MachineOperand &VLOp) {
+    IsImm = VLOp.isImm();
+    if (IsImm)
+      Imm = VLOp.getImm();
+    else
+      Reg = VLOp.getReg();
+  }
+
+  Register Reg;
+  int64_t Imm;
+  bool IsImm;
+
+  bool isCompatible(const MachineOperand &VLOp) const {
+    if (IsImm != VLOp.isImm())
+      return false;
+    if (IsImm)
+      return Imm == VLOp.getImm();
+    return Reg == VLOp.getReg();
+  }
+
+  bool isValid() const { return IsImm || Reg.isVirtual(); }
+
+  bool hasBenefit(const MachineOperand &VLOp) const {
+    if (IsImm && Imm == RISCV::VLMaxSentinel)
+      return false;
+
+    if (!IsImm || !VLOp.isImm())
+      return true;
+
+    if (VLOp.getImm() == RISCV::VLMaxSentinel)
+      return true;
+
+    // No benefit if the current VL is already smaller than the new one.
+    return Imm < VLOp.getImm();
+  }
+};
+
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -51,7 +89,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -660,8 +698,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
 
   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
   const MachineOperand &VLOp = MI.getOperand(VLOpNum);
-  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+  if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
+       VLOp.isReg())) {
+    bool UseTAPolicy = false;
+    bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+    if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
+      unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
+      const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
+      uint64_t Policy = PolicyOp.getImm();
+      UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+      if (HasPassthru) {
+        unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+        UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
+                                      RISCV::NoRegister);
+      }
+    }
+    if (!UseTAPolicy) {
+      LLVM_DEBUG(
+          dbgs() << "  Not a candidate because it uses tail-undisturbed policy"
+                    " with non-VLMAX VL\n");
+      return false;
+    }
+  }
+
+  // If the VL is 1, then there is no need to reduce it.
+  if (VLOp.isImm() && VLOp.getImm() == 1) {
+    LLVM_DEBUG(dbgs() << "  Not a candidate because VL is already 1\n");
     return false;
+  }
 
   // Some instructions that produce vectors have semantics that make it more
   // difficult to determine whether the VL can be reduced. For example, some
@@ -684,7 +748,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
                                   MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
@@ -738,8 +802,9 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
     }
 
     if (!CommonVL) {
-      CommonVL = VLOp.getReg();
-    } else if (*CommonVL != VLOp.getReg()) {
+      CommonVL = VLInfo(VLOp);
+      LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
+    } else if (!CommonVL->isCompatible(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -776,7 +841,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<Register> CommonVL;
+    std::optional<VLInfo> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -784,21 +849,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isVirtual()) {
-      LLVM_DEBUG(
-          dbgs() << "    Abort due to new VL is not virtual register.\n");
+    if (!CommonVL->isValid()) {
+      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
       continue;
     }
 
-    const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
-    if (!MDT->dominates(VLMI, &MI))
-      continue;
-
-    // All our checks passed. We can reduce VL.
-    LLVM_DEBUG(dbgs() << "    Reducing VL for: " << MI << "\n");
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
-    VLOp.ChangeToRegister(*CommonVL, false);
+
+    if (!CommonVL->hasBenefit(VLOp)) {
+      LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
+      continue;
+    }
+
+    if (CommonVL->IsImm) {
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << CommonVL->Imm << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->Imm);
+    } else {
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      if (!MDT->dominates(VLMI, &MI))
+        continue;
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
+                        << " for " << MI << "\n");
+
+      // All our checks passed. We can reduce VL.
+      VLOp.ChangeToRegister(CommonVL->Reg, false);
+    }
+
     MadeChange = true;
 
     // Now add all inputs to this instruction to the worklist.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index 0b3e67ec895566..e8ac4efc770484 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -23,7 +23,7 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
   ret <vscale x 4 x i32> %w
 }
 
-; No benificial to propagate VL since VL is larger in the use side.
+; Not beneficial to propagate VL since VL is larger in the use side.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
 ; CHECK:       # %bb.0:
@@ -38,20 +38,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_reg_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_reg_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
 }
 
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_1:
 ; CHECK:       # %bb.0:
@@ -69,13 +75,20 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
 ; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
 ; undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v10, v8, v10
-; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
+; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v10, v8, v10
+; VLOPT-NEXT:    vadd.vv v8, v10, v8
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w
@@ -110,7 +123,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
   ret <vscale x 4 x i32> %w
 }
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; NOVLOPT: {{.*}}
-; VLOPT: {{.*}}

>From 76aeaa6367c3a8ab874a35373b26c6c3247f10d3 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 07:08:51 -0700
Subject: [PATCH 02/10] fixup! use std::variant

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 60 +++++++++++++---------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 82972f07157124..9940370b9e27f0 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -32,40 +32,51 @@ using namespace llvm;
 namespace {
 
 struct VLInfo {
+  std::variant<Register, int64_t> VL;
+
   VLInfo(const MachineOperand &VLOp) {
-    IsImm = VLOp.isImm();
-    if (IsImm)
-      Imm = VLOp.getImm();
+    if (VLOp.isImm())
+      VL = VLOp.getImm();
     else
-      Reg = VLOp.getReg();
+      VL = VLOp.getReg();
   }
 
-  Register Reg;
-  int64_t Imm;
-  bool IsImm;
-
   bool isCompatible(const MachineOperand &VLOp) const {
-    if (IsImm != VLOp.isImm())
+    if (isImm() != VLOp.isImm())
       return false;
-    if (IsImm)
-      return Imm == VLOp.getImm();
-    return Reg == VLOp.getReg();
+    if (isImm())
+      return getImm() == VLOp.getImm();
+    return getReg() == VLOp.getReg();
+  }
+
+  bool isImm() const { return std::holds_alternative<int64_t>(VL); }
+
+  bool isReg() const { return std::holds_alternative<Register>(VL); }
+
+  bool isValid() const { return isImm() || getReg().isVirtual(); }
+
+  int64_t getImm() const {
+    assert (isImm() && "Expected VL to be an immediate");
+    return std::get<int64_t>(VL);
   }
 
-  bool isValid() const { return IsImm || Reg.isVirtual(); }
+  Register getReg() const {
+    assert (isReg() && "Expected VL to be a Register");
+    return std::get<Register>(VL);
+  }
 
   bool hasBenefit(const MachineOperand &VLOp) const {
-    if (IsImm && Imm == RISCV::VLMaxSentinel)
+    if (isImm() && getImm() == RISCV::VLMaxSentinel)
       return false;
 
-    if (!IsImm || !VLOp.isImm())
+    if (!isImm() || !VLOp.isImm())
       return true;
 
     if (VLOp.getImm() == RISCV::VLMaxSentinel)
       return true;
 
     // No benefit if the current VL is already smaller than the new one.
-    return Imm < VLOp.getImm();
+    return getImm() < VLOp.getImm();
   }
 };
 
@@ -862,20 +873,21 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
       continue;
     }
 
-    if (CommonVL->IsImm) {
+    if (CommonVL->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVL->Imm << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVL->Imm);
+                        << CommonVL->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
-      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
-                        << " for " << MI << "\n");
+      LLVM_DEBUG(
+          dbgs() << "  Reduce VL from " << VLOp << " to "
+                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+                 << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVL->Reg, false);
+      VLOp.ChangeToRegister(CommonVL->getReg(), false);
     }
 
     MadeChange = true;

>From 427ba419a2f5fa86de04663ff67f4aaf4b9d44b1 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 07:25:41 -0700
Subject: [PATCH 03/10] fixup! clang-format

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 9940370b9e27f0..b896ad6a03552c 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -56,12 +56,12 @@ struct VLInfo {
   bool isValid() const { return isImm() || getReg().isVirtual(); }
 
   int64_t getImm() const {
-    assert (isImm() && "Expected VL to be an immediate");
+    assert(isImm() && "Expected VL to be an immediate");
     return std::get<int64_t>(VL);
   }
 
   Register getReg() const {
-    assert (isReg() && "Expected VL to be a Register");
+    assert(isReg() && "Expected VL to be a Register");
     return std::get<Register>(VL);
   }
 

>From 49f8f5ce8fdea65b1ebe272a0d01535bf380d0c6 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 08:15:28 -0700
Subject: [PATCH 04/10] fixup! remove VLInfo

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 14 +++
 llvm/lib/Target/RISCV/RISCVInstrInfo.h        |  3 +
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp    | 85 +++++--------------
 llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp | 22 +----
 llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll | 37 +++++---
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll         | 63 ++++++--------
 6 files changed, 93 insertions(+), 131 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index b8539a5d1add14..3989a966edfd33 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
   assert(Scaled >= 3 && Scaled <= 6);
   return Scaled;
 }
+
+/// Given two VL operands, do we know that LHS <= RHS?
+bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
+  if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
+      LHS.getReg() == RHS.getReg())
+    return true;
+  if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
+    return true;
+  if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
+    return false;
+  if (!LHS.isImm() || !RHS.isImm())
+    return false;
+  return LHS.getImm() <= RHS.getImm();
+}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 457db9b9860d00..c3aa367486627a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW);
 // Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
 static constexpr int64_t VLMaxSentinel = -1LL;
 
+/// Given two VL operands, do we know that LHS <= RHS?
+bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS);
+
 // Mask assignments for floating-point
 static constexpr unsigned FPMASK_Negative_Infinity = 0x001;
 static constexpr unsigned FPMASK_Negative_Normal = 0x002;
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index b896ad6a03552c..72ebe4d6cc42ab 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,55 +31,6 @@ using namespace llvm;
 
 namespace {
 
-struct VLInfo {
-  std::variant<Register, int64_t> VL;
-
-  VLInfo(const MachineOperand &VLOp) {
-    if (VLOp.isImm())
-      VL = VLOp.getImm();
-    else
-      VL = VLOp.getReg();
-  }
-
-  bool isCompatible(const MachineOperand &VLOp) const {
-    if (isImm() != VLOp.isImm())
-      return false;
-    if (isImm())
-      return getImm() == VLOp.getImm();
-    return getReg() == VLOp.getReg();
-  }
-
-  bool isImm() const { return std::holds_alternative<int64_t>(VL); }
-
-  bool isReg() const { return std::holds_alternative<Register>(VL); }
-
-  bool isValid() const { return isImm() || getReg().isVirtual(); }
-
-  int64_t getImm() const {
-    assert(isImm() && "Expected VL to be an immediate");
-    return std::get<int64_t>(VL);
-  }
-
-  Register getReg() const {
-    assert(isReg() && "Expected VL to be a Register");
-    return std::get<Register>(VL);
-  }
-
-  bool hasBenefit(const MachineOperand &VLOp) const {
-    if (isImm() && getImm() == RISCV::VLMaxSentinel)
-      return false;
-
-    if (!isImm() || !VLOp.isImm())
-      return true;
-
-    if (VLOp.getImm() == RISCV::VLMaxSentinel)
-      return true;
-
-    // No benefit if the current VL is already smaller than the new one.
-    return getImm() < VLOp.getImm();
-  }
-};
-
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -100,7 +51,8 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<const MachineOperand *> &CommonVL,
+                  MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -759,8 +711,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
-                                  MachineInstr &MI) {
+bool RISCVVLOptimizer::checkUsers(
+    std::optional<const MachineOperand *> &CommonVL, MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
   // along lines of an instcombine style worklist which integrates the outer
@@ -805,17 +757,17 @@ bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
 
     unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
     const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
-    // Looking for a register VL that isn't X0.
-    if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) {
+    // Looking for an immediate or a register VL that isn't X0.
+    if (VLOp.isReg() && VLOp.getReg() == RISCV::X0) {
       LLVM_DEBUG(dbgs() << "    Abort due to user uses X0 as VL.\n");
       CanReduceVL = false;
       break;
     }
 
     if (!CommonVL) {
-      CommonVL = VLInfo(VLOp);
+      CommonVL = &VLOp;
       LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
-    } else if (!CommonVL->isCompatible(VLOp)) {
+    } else if (!(*CommonVL)->isIdenticalTo(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -852,7 +804,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<VLInfo> CommonVL;
+    std::optional<const MachineOperand *> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -860,34 +812,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isValid()) {
-      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
+    const MachineOperand *CommonVLMO = *CommonVL;
+    if (!CommonVLMO->isImm() && !CommonVLMO->getReg().isVirtual()) {
+      LLVM_DEBUG(dbgs() << "    Abort because common VL is not valid.\n");
       continue;
     }
 
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
 
-    if (!CommonVL->hasBenefit(VLOp)) {
+    if (!RISCV::isVLKnownLE(*CommonVLMO, VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
       continue;
     }
 
-    if (CommonVL->isImm()) {
+    if (CommonVLMO->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVL->getImm() << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVL->getImm());
+                        << CommonVLMO->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVLMO->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVLMO->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
       LLVM_DEBUG(
           dbgs() << "  Reduce VL from " << VLOp << " to "
-                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+                 << printReg(CommonVLMO->getReg(), MRI->getTargetRegisterInfo())
                  << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVL->getReg(), false);
+      VLOp.ChangeToRegister(CommonVLMO->getReg(), false);
     }
 
     MadeChange = true;
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b883c50beadc09..a57bc5a3007d03 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0;
 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
                 false)
 
-/// Given two VL operands, do we know that LHS <= RHS?
-static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
-  if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
-      LHS.getReg() == RHS.getReg())
-    return true;
-  if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
-    return true;
-  if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
-    return false;
-  if (!LHS.isImm() || !RHS.isImm())
-    return false;
-  return LHS.getImm() <= RHS.getImm();
-}
-
 /// Given \p User that has an input operand with EEW=SEW, which uses the dest
 /// operand of \p Src with an unknown EEW, return true if their EEWs match.
 bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
@@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
     return false;
 
   MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
-  if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
+  if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
     return false;
 
   if (!ensureDominates(VL, *Src))
@@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
     MachineOperand &SrcPolicy =
         Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));
 
-    if (isVLKnownLE(MIVL, SrcVL))
+    if (RISCV::isVLKnownLE(MIVL, SrcVL))
       SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
   }
 
@@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
   // so we don't need to handle a smaller source VL here.  However, the
   // user's VL may be larger
   MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
-  if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
+  if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
     return false;
 
   // If the new passthru doesn't dominate Src, try to move Src so it does.
@@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
   // If MI was tail agnostic and the VL didn't increase, preserve it.
   int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
   if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
-      isVLKnownLE(MI.getOperand(3), SrcVL))
+      RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
     Policy |= RISCVII::TAIL_AGNOSTIC;
   Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
index 1a01a9bf77cff5..57a05d60679d3c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
@@ -9,18 +9,31 @@
 ; RUN:   -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
 
 define <2 x i32> @vdot_lane_s32(<2 x i32> noundef %var_1, <8 x i8> noundef %var_3, <8 x i8> noundef %var_5, <8 x i16> %x) {
-; CHECK-LABEL: vdot_lane_s32:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v11, 0
-; CHECK-NEXT:    vnsrl.wi v9, v11, 16
-; CHECK-NEXT:    vwadd.vv v10, v8, v9
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v10, 0
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vnsrl.wx v9, v10, a0
-; CHECK-NEXT:    vadd.vv v8, v8, v9
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: vdot_lane_s32:
+; NOVLOPT:       # %bb.0: # %entry
+; NOVLOPT-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; NOVLOPT-NEXT:    vnsrl.wi v8, v11, 0
+; NOVLOPT-NEXT:    vnsrl.wi v9, v11, 16
+; NOVLOPT-NEXT:    vwadd.vv v10, v8, v9
+; NOVLOPT-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NOVLOPT-NEXT:    vnsrl.wi v8, v10, 0
+; NOVLOPT-NEXT:    li a0, 32
+; NOVLOPT-NEXT:    vnsrl.wx v9, v10, a0
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v9
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: vdot_lane_s32:
+; VLOPT:       # %bb.0: # %entry
+; VLOPT-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; VLOPT-NEXT:    vnsrl.wi v8, v11, 0
+; VLOPT-NEXT:    vnsrl.wi v9, v11, 16
+; VLOPT-NEXT:    vwadd.vv v10, v8, v9
+; VLOPT-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; VLOPT-NEXT:    vnsrl.wi v8, v10, 0
+; VLOPT-NEXT:    li a0, 32
+; VLOPT-NEXT:    vnsrl.wx v9, v10, a0
+; VLOPT-NEXT:    vadd.vv v8, v8, v9
+; VLOPT-NEXT:    ret
 entry:
   %a = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
   %b = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index e8ac4efc770484..7fb245bff5d5be 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -11,13 +11,20 @@
 declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
 
 define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 5, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 5, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
   ret <vscale x 4 x i32> %w
@@ -38,20 +45,13 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
-; NOVLOPT:       # %bb.0:
-; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
-; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
-; NOVLOPT-NEXT:    ret
-;
-; VLOPT-LABEL: different_imm_reg_vl_with_ta:
-; VLOPT:       # %bb.0:
-; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; VLOPT-NEXT:    vadd.vv v8, v10, v12
-; VLOPT-NEXT:    vadd.vv v8, v8, v10
-; VLOPT-NEXT:    ret
+; CHECK-LABEL: different_imm_reg_vl_with_ta:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
@@ -75,20 +75,13 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
; it's still safe even if %vl2 is larger than %vl1, because the rest of the
; vector is undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; NOVLOPT-LABEL: different_vl_with_ta:
-; NOVLOPT:       # %bb.0:
-; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
-; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
-; NOVLOPT-NEXT:    ret
-;
-; VLOPT-LABEL: different_vl_with_ta:
-; VLOPT:       # %bb.0:
-; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; VLOPT-NEXT:    vadd.vv v10, v8, v10
-; VLOPT-NEXT:    vadd.vv v8, v10, v8
-; VLOPT-NEXT:    ret
+; CHECK-LABEL: different_vl_with_ta:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v10, v8, v10
+; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w

>From 0bb4f31e8335eb41f7742d60b67f7870a8967c01 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 09:55:46 -0700
Subject: [PATCH 05/10] fixup! respond to luke's review

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 42 ++++++++++------------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 72ebe4d6cc42ab..98720acb57d4e4 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -51,8 +51,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<const MachineOperand *> &CommonVL,
-                  MachineInstr &MI);
+  bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -669,7 +668,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
       unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
       const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
       uint64_t Policy = PolicyOp.getImm();
-      UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+      UseTAPolicy = Policy & RISCVII::TAIL_AGNOSTIC;
       if (HasPassthru) {
         unsigned PassthruOpIdx = MI.getNumExplicitDefs();
         UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
@@ -711,8 +710,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(
-    std::optional<const MachineOperand *> &CommonVL, MachineInstr &MI) {
+bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
+                                  MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
   // along lines of an instcombine style worklist which integrates the outer
@@ -757,17 +756,15 @@ bool RISCVVLOptimizer::checkUsers(
 
     unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
     const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
+
     // Looking for an immediate or a register VL that isn't X0.
-    if (VLOp.isReg() && VLOp.getReg() == RISCV::X0) {
-      LLVM_DEBUG(dbgs() << "    Abort due to user uses X0 as VL.\n");
-      CanReduceVL = false;
-      break;
-    }
+    assert(!VLOp.isReg() ||
+           VLOp.getReg() != RISCV::X0 && "Did not expect X0 VL");
 
     if (!CommonVL) {
       CommonVL = &VLOp;
       LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
-    } else if (!(*CommonVL)->isIdenticalTo(VLOp)) {
+    } else if (!CommonVL->isIdenticalTo(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -804,7 +801,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<const MachineOperand *> CommonVL;
+    const MachineOperand *CommonVL = nullptr;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -812,35 +809,32 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    const MachineOperand *CommonVLMO = *CommonVL;
-    if (!CommonVLMO->isImm() && !CommonVLMO->getReg().isVirtual()) {
-      LLVM_DEBUG(dbgs() << "    Abort because common VL is not valid.\n");
-      continue;
-    }
+    assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
+           "Expected VL to be an Imm or virtual Reg");
 
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
 
-    if (!RISCV::isVLKnownLE(*CommonVLMO, VLOp)) {
+    if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
       continue;
     }
 
-    if (CommonVLMO->isImm()) {
+    if (CommonVL->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVLMO->getImm() << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVLMO->getImm());
+                        << CommonVL->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVLMO->getReg());
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
       LLVM_DEBUG(
           dbgs() << "  Reduce VL from " << VLOp << " to "
-                 << printReg(CommonVLMO->getReg(), MRI->getTargetRegisterInfo())
+                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
                  << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVLMO->getReg(), false);
+      VLOp.ChangeToRegister(CommonVL->getReg(), false);
     }
 
     MadeChange = true;

>From 1f522cf76ccd1b89ecdf218124b111a287372113 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 09:58:55 -0700
Subject: [PATCH 06/10] fixup! add vlmax and imm test case

---
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index 7fb245bff5d5be..1a1472fcfc66f5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -30,6 +30,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
   ret <vscale x 4 x i32> %w
 }
 
+define <vscale x 4 x i32> @vlmax_and_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
+; NOVLOPT-LABEL: vlmax_and_imm_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: vlmax_and_imm_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
+  %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen -1)
+  %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
+  ret <vscale x 4 x i32> %w
+}
+
 ; Not beneficial to propagate VL since VL is larger in the use side.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:

>From 55a46c24be7708fe4240c7896f1e0d2a12fb821d Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Wed, 16 Oct 2024 07:09:20 -0700
Subject: [PATCH 07/10] fixup! respond to review

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 29 ++++++++++------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 98720acb57d4e4..979f39c8721694 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -658,26 +658,23 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   if (MI.getNumDefs() != 1)
     return false;
 
+  // If we're not using VLMAX, then we need to be careful whether we are using
+  // TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it
+  // does not matter whether we are using TA/TU with a non-undef Passthru, since
+// there are no tail elements to be preserved.
   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
   const MachineOperand &VLOp = MI.getOperand(VLOpNum);
-  if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
-       VLOp.isReg())) {
-    bool UseTAPolicy = false;
+  if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) {
+    // If MI has a non-undef passthru, we will not try to optimize it since
+    // that requires us to preserve tail elements according to TA/TU.
+    // Otherwise, the MI has an undef Passthru, so it doesn't matter whether we
+    // are using TA/TU.
     bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
-    if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
-      unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
-      const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
-      uint64_t Policy = PolicyOp.getImm();
-      UseTAPolicy = Policy & RISCVII::TAIL_AGNOSTIC;
-      if (HasPassthru) {
-        unsigned PassthruOpIdx = MI.getNumExplicitDefs();
-        UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
-                                      RISCV::NoRegister);
-      }
-    }
-    if (!UseTAPolicy) {
+    unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+    if (HasPassthru &&
+        MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) {
       LLVM_DEBUG(
-          dbgs() << "  Not a candidate because it uses tail-undisturbed policy"
+          dbgs() << "  Not a candidate because it uses non-undef passthru"
                     " with non-VLMAX VL\n");
       return false;
     }

>From 715e116256a0f937918a5af0b1b1ca7d44f855c4 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Wed, 16 Oct 2024 08:39:56 -0700
Subject: [PATCH 08/10] fixup! update comment and debug for clarity

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 979f39c8721694..76245d30b18039 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -680,7 +680,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
     }
   }
 
-  // If the VL is 1, then there is no need to reduce it.
+  // If the VL is 1, then there is no need to reduce it. This is an
+  // optimization, not needed to preserve correctness.
   if (VLOp.isImm() && VLOp.getImm() == 1) {
     LLVM_DEBUG(dbgs() << "  Not a candidate because VL is already 1\n");
     return false;
@@ -762,6 +763,8 @@ bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
       CommonVL = &VLOp;
       LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
     } else if (!CommonVL->isIdenticalTo(VLOp)) {
+      // FIXME: This check requires all users to have the same VL. We can relax
+      // this and get the largest VL amongst all users.
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -813,7 +816,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
 
     if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
-      LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
+      LLVM_DEBUG(dbgs() << "    Abort due to no CommonVL not <= VLOp.\n");
       continue;
     }
 

>From 5a6976baf48dabffee0d18f0609979c4f65fde34 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Wed, 16 Oct 2024 10:56:47 -0700
Subject: [PATCH 09/10] fixup! remove accidental double negative

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 76245d30b18039..ee494c46815112 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -816,7 +816,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
 
     if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
-      LLVM_DEBUG(dbgs() << "    Abort due to no CommonVL not <= VLOp.\n");
+      LLVM_DEBUG(dbgs() << "    Abort due to CommonVL not <= VLOp.\n");
       continue;
     }
 

>From b3a33c688eff117eee40d262a842d9695a215c37 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Wed, 16 Oct 2024 11:33:01 -0700
Subject: [PATCH 10/10] fixup! fix test case after rebase

---
 llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll | 37 ++++++-------------
 1 file changed, 12 insertions(+), 25 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
index 57a05d60679d3c..1a01a9bf77cff5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.ll
@@ -9,31 +9,18 @@
 ; RUN:   -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
 
 define <2 x i32> @vdot_lane_s32(<2 x i32> noundef %var_1, <8 x i8> noundef %var_3, <8 x i8> noundef %var_5, <8 x i16> %x) {
-; NOVLOPT-LABEL: vdot_lane_s32:
-; NOVLOPT:       # %bb.0: # %entry
-; NOVLOPT-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
-; NOVLOPT-NEXT:    vnsrl.wi v8, v11, 0
-; NOVLOPT-NEXT:    vnsrl.wi v9, v11, 16
-; NOVLOPT-NEXT:    vwadd.vv v10, v8, v9
-; NOVLOPT-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; NOVLOPT-NEXT:    vnsrl.wi v8, v10, 0
-; NOVLOPT-NEXT:    li a0, 32
-; NOVLOPT-NEXT:    vnsrl.wx v9, v10, a0
-; NOVLOPT-NEXT:    vadd.vv v8, v8, v9
-; NOVLOPT-NEXT:    ret
-;
-; VLOPT-LABEL: vdot_lane_s32:
-; VLOPT:       # %bb.0: # %entry
-; VLOPT-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
-; VLOPT-NEXT:    vnsrl.wi v8, v11, 0
-; VLOPT-NEXT:    vnsrl.wi v9, v11, 16
-; VLOPT-NEXT:    vwadd.vv v10, v8, v9
-; VLOPT-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; VLOPT-NEXT:    vnsrl.wi v8, v10, 0
-; VLOPT-NEXT:    li a0, 32
-; VLOPT-NEXT:    vnsrl.wx v9, v10, a0
-; VLOPT-NEXT:    vadd.vv v8, v8, v9
-; VLOPT-NEXT:    ret
+; CHECK-LABEL: vdot_lane_s32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v11, 0
+; CHECK-NEXT:    vnsrl.wi v9, v11, 16
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v10, 0
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vnsrl.wx v9, v10, a0
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
 entry:
   %a = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
   %b = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>



More information about the llvm-commits mailing list