[llvm] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (PR #112228)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 15 07:09:11 PDT 2024


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/112228

>From 6b5faa1cd1e92a6450654e120740385311824341 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 14 Oct 2024 08:36:42 -0700
Subject: [PATCH 1/2] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX

The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.

---------

Co-authored-by: Kito Cheng <kito.cheng at sifive.com>
---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 111 ++++++++++++++++++---
 llvm/test/CodeGen/RISCV/rvv/vl-opt.ll      |  51 ++++++----
 2 files changed, 125 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index eb1f4df4ff7264..2c87826ab7883c 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,6 +31,44 @@ using namespace llvm;
 
 namespace {
 
+struct VLInfo {
+  VLInfo(const MachineOperand &VLOp) {
+    IsImm = VLOp.isImm();
+    if (IsImm)
+      Imm = VLOp.getImm();
+    else
+      Reg = VLOp.getReg();
+  }
+
+  Register Reg;
+  int64_t Imm;
+  bool IsImm;
+
+  bool isCompatible(const MachineOperand &VLOp) const {
+    if (IsImm != VLOp.isImm())
+      return false;
+    if (IsImm)
+      return Imm == VLOp.getImm();
+    return Reg == VLOp.getReg();
+  }
+
+  bool isValid() const { return IsImm || Reg.isVirtual(); }
+
+  bool hasBenefit(const MachineOperand &VLOp) const {
+    if (IsImm && Imm == RISCV::VLMaxSentinel)
+      return false;
+
+    if (!IsImm || !VLOp.isImm())
+      return true;
+
+    if (VLOp.getImm() == RISCV::VLMaxSentinel)
+      return true;
+
+    // No benefit if the current VL is already smaller than the new one.
+    return Imm < VLOp.getImm();
+  }
+};
+
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -51,7 +89,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -643,8 +681,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
 
   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
   const MachineOperand &VLOp = MI.getOperand(VLOpNum);
-  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+  if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
+       VLOp.isReg())) {
+    bool UseTAPolicy = false;
+    bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+    if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
+      unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
+      const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
+      uint64_t Policy = PolicyOp.getImm();
+      UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+      if (HasPassthru) {
+        unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+        UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
+                                      RISCV::NoRegister);
+      }
+    }
+    if (!UseTAPolicy) {
+      LLVM_DEBUG(
+          dbgs() << "  Not a candidate because it uses tail-undisturbed policy"
+                    " with non-VLMAX VL\n");
+      return false;
+    }
+  }
+
+  // If the VL is 1, then there is no need to reduce it.
+  if (VLOp.isImm() && VLOp.getImm() == 1) {
+    LLVM_DEBUG(dbgs() << "  Not a candidate because VL is already 1\n");
     return false;
+  }
 
   // Some instructions that produce vectors have semantics that make it more
   // difficult to determine whether the VL can be reduced. For example, some
@@ -667,7 +731,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
                                   MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
@@ -721,8 +785,9 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
     }
 
     if (!CommonVL) {
-      CommonVL = VLOp.getReg();
-    } else if (*CommonVL != VLOp.getReg()) {
+      CommonVL = VLInfo(VLOp);
+      LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
+    } else if (!CommonVL->isCompatible(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -759,7 +824,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<Register> CommonVL;
+    std::optional<VLInfo> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -767,21 +832,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isVirtual()) {
-      LLVM_DEBUG(
-          dbgs() << "    Abort due to new VL is not virtual register.\n");
+    if (!CommonVL->isValid()) {
+      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
       continue;
     }
 
-    const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
-    if (!MDT->dominates(VLMI, &MI))
-      continue;
-
-    // All our checks passed. We can reduce VL.
-    LLVM_DEBUG(dbgs() << "    Reducing VL for: " << MI << "\n");
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
-    VLOp.ChangeToRegister(*CommonVL, false);
+
+    if (!CommonVL->hasBenefit(VLOp)) {
+      LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
+      continue;
+    }
+
+    if (CommonVL->IsImm) {
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << CommonVL->Imm << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->Imm);
+    } else {
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      if (!MDT->dominates(VLMI, &MI))
+        continue;
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
+                        << " for " << MI << "\n");
+
+      // All our checks passed. We can reduce VL.
+      VLOp.ChangeToRegister(CommonVL->Reg, false);
+    }
+
     MadeChange = true;
 
     // Now add all inputs to this instruction to the worklist.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index 0b3e67ec895566..e8ac4efc770484 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -23,7 +23,7 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
   ret <vscale x 4 x i32> %w
 }
 
-; No benificial to propagate VL since VL is larger in the use side.
+; Not beneficial to propagate VL since VL is larger in the use side.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
 ; CHECK:       # %bb.0:
@@ -38,20 +38,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_reg_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_reg_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
 }
 
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_1:
 ; CHECK:       # %bb.0:
@@ -69,13 +75,20 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
 ; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
 ; undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v10, v8, v10
-; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
+; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v10, v8, v10
+; VLOPT-NEXT:    vadd.vv v8, v10, v8
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w
@@ -110,7 +123,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
   ret <vscale x 4 x i32> %w
 }
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; NOVLOPT: {{.*}}
-; VLOPT: {{.*}}

>From 9ad694e590c8ffd1216b1fbae4f701c1d95a5971 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 15 Oct 2024 07:08:51 -0700
Subject: [PATCH 2/2] fixup! use std::variant

---
 llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 60 +++++++++++++---------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 2c87826ab7883c..9b69aa0c60c4a8 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -32,40 +32,51 @@ using namespace llvm;
 namespace {
 
 struct VLInfo {
+  std::variant<Register, int64_t> VL;
+
   VLInfo(const MachineOperand &VLOp) {
-    IsImm = VLOp.isImm();
-    if (IsImm)
-      Imm = VLOp.getImm();
+    if (VLOp.isImm())
+      VL = VLOp.getImm();
     else
-      Reg = VLOp.getReg();
+      VL = VLOp.getReg();
   }
 
-  Register Reg;
-  int64_t Imm;
-  bool IsImm;
-
   bool isCompatible(const MachineOperand &VLOp) const {
-    if (IsImm != VLOp.isImm())
+    if (isImm() != VLOp.isImm())
       return false;
-    if (IsImm)
-      return Imm == VLOp.getImm();
-    return Reg == VLOp.getReg();
+    if (isImm())
+      return getImm() == VLOp.getImm();
+    return getReg() == VLOp.getReg();
+  }
+
+  bool isImm() const { return std::holds_alternative<int64_t>(VL); }
+
+  bool isReg() const { return std::holds_alternative<Register>(VL); }
+
+  bool isValid() const { return isImm() || getReg().isVirtual(); }
+
+  int64_t getImm() const {
+    assert (isImm() && "Expected VL to be an immediate");
+    return std::get<int64_t>(VL);
   }
 
-  bool isValid() const { return IsImm || Reg.isVirtual(); }
+  Register getReg() const {
+    assert (isReg() && "Expected VL to be a Register");
+    return std::get<Register>(VL);
+  }
 
   bool hasBenefit(const MachineOperand &VLOp) const {
-    if (IsImm && Imm == RISCV::VLMaxSentinel)
+    if (isImm() && getImm() == RISCV::VLMaxSentinel)
       return false;
 
-    if (!IsImm || !VLOp.isImm())
+    if (!isImm() || !VLOp.isImm())
       return true;
 
     if (VLOp.getImm() == RISCV::VLMaxSentinel)
       return true;
 
     // No benefit if the current VL is already smaller than the new one.
-    return Imm < VLOp.getImm();
+    return getImm() < VLOp.getImm();
   }
 };
 
@@ -845,20 +856,21 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
       continue;
     }
 
-    if (CommonVL->IsImm) {
+    if (CommonVL->isImm()) {
       LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << CommonVL->Imm << " for " << MI << "\n");
-      VLOp.ChangeToImmediate(CommonVL->Imm);
+                        << CommonVL->getImm() << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->getImm());
     } else {
-      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
       if (!MDT->dominates(VLMI, &MI))
         continue;
-      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
-                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
-                        << " for " << MI << "\n");
+      LLVM_DEBUG(
+          dbgs() << "  Reduce VL from " << VLOp << " to "
+                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+                 << " for " << MI << "\n");
 
       // All our checks passed. We can reduce VL.
-      VLOp.ChangeToRegister(CommonVL->Reg, false);
+      VLOp.ChangeToRegister(CommonVL->getReg(), false);
     }
 
     MadeChange = true;



More information about the llvm-commits mailing list