[llvm] [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (PR #112228)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 14 08:57:15 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Michael Maitland (michaelmaitland)
<details>
<summary>Changes</summary>
Originally, this pass only handled vector operations whose VL is VLMAX.
However, users often use only part of the result, and such partial uses may
also come from the vectorizer.
We found that relaxing this constraint captures more optimization
opportunities, such as non-power-of-2 code generation and sequences of vector
operations with different VLs.
---------
Co-authored-by: Kito Cheng <kito.cheng@<!-- -->sifive.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/112228.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp (+95-16)
- (modified) llvm/test/CodeGen/RISCV/rvv/vl-opt.ll (+39-20)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index eb1f4df4ff7264..2c87826ab7883c 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,6 +31,44 @@ using namespace llvm;
namespace {
+struct VLInfo {
+ VLInfo(const MachineOperand &VLOp) {
+ IsImm = VLOp.isImm();
+ if (IsImm)
+ Imm = VLOp.getImm();
+ else
+ Reg = VLOp.getReg();
+ }
+
+ Register Reg;
+ int64_t Imm;
+ bool IsImm;
+
+ bool isCompatible(const MachineOperand &VLOp) const {
+ if (IsImm != VLOp.isImm())
+ return false;
+ if (IsImm)
+ return Imm == VLOp.getImm();
+ return Reg == VLOp.getReg();
+ }
+
+ bool isValid() const { return IsImm || Reg.isVirtual(); }
+
+ bool hasBenefit(const MachineOperand &VLOp) const {
+ if (IsImm && Imm == RISCV::VLMaxSentinel)
+ return false;
+
+ if (!IsImm || !VLOp.isImm())
+ return true;
+
+ if (VLOp.getImm() == RISCV::VLMaxSentinel)
+ return true;
+
+ // No benefit if the current VL is already smaller than the new one.
+ return Imm < VLOp.getImm();
+ }
+};
+
class RISCVVLOptimizer : public MachineFunctionPass {
const MachineRegisterInfo *MRI;
const MachineDominatorTree *MDT;
@@ -51,7 +89,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
StringRef getPassName() const override { return PASS_NAME; }
private:
- bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+ bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
bool tryReduceVL(MachineInstr &MI);
bool isCandidate(const MachineInstr &MI) const;
};
@@ -643,8 +681,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = MI.getOperand(VLOpNum);
- if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+ if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
+ VLOp.isReg())) {
+ bool UseTAPolicy = false;
+ bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+ if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
+ unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
+ const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
+ uint64_t Policy = PolicyOp.getImm();
+ UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+ if (HasPassthru) {
+ unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+ UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
+ RISCV::NoRegister);
+ }
+ }
+ if (!UseTAPolicy) {
+ LLVM_DEBUG(
+ dbgs() << " Not a candidate because it uses tail-undisturbed policy"
+ " with non-VLMAX VL\n");
+ return false;
+ }
+ }
+
+ // If the VL is 1, then there is no need to reduce it.
+ if (VLOp.isImm() && VLOp.getImm() == 1) {
+ LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
return false;
+ }
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
@@ -667,7 +731,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return true;
}
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
@@ -721,8 +785,9 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
}
if (!CommonVL) {
- CommonVL = VLOp.getReg();
- } else if (*CommonVL != VLOp.getReg()) {
+ CommonVL = VLInfo(VLOp);
+ LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
+ } else if (!CommonVL->isCompatible(VLOp)) {
LLVM_DEBUG(dbgs() << " Abort because users have different VL\n");
CanReduceVL = false;
break;
@@ -759,7 +824,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
MachineInstr &MI = *Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
- std::optional<Register> CommonVL;
+ std::optional<VLInfo> CommonVL;
bool CanReduceVL = true;
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
CanReduceVL = checkUsers(CommonVL, MI);
@@ -767,21 +832,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
if (!CanReduceVL || !CommonVL)
continue;
- if (!CommonVL->isVirtual()) {
- LLVM_DEBUG(
- dbgs() << " Abort due to new VL is not virtual register.\n");
+ if (!CommonVL->isValid()) {
+ LLVM_DEBUG(dbgs() << " Abort due to common VL is not valid.\n");
continue;
}
- const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
- if (!MDT->dominates(VLMI, &MI))
- continue;
-
- // All our checks passed. We can reduce VL.
- LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n");
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
MachineOperand &VLOp = MI.getOperand(VLOpNum);
- VLOp.ChangeToRegister(*CommonVL, false);
+
+ if (!CommonVL->hasBenefit(VLOp)) {
+ LLVM_DEBUG(dbgs() << " Abort due to no benefit.\n");
+ continue;
+ }
+
+ if (CommonVL->IsImm) {
+ LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
+ << CommonVL->Imm << " for " << MI << "\n");
+ VLOp.ChangeToImmediate(CommonVL->Imm);
+ } else {
+ const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+ if (!MDT->dominates(VLMI, &MI))
+ continue;
+ LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
+ << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
+ << " for " << MI << "\n");
+
+ // All our checks passed. We can reduce VL.
+ VLOp.ChangeToRegister(CommonVL->Reg, false);
+ }
+
MadeChange = true;
// Now add all inputs to this instruction to the worklist.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index b03ba076059503..e8ac4efc770484 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -1,6 +1,12 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
-; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs | FileCheck %s
-; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs | FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs | \
+; RUN: FileCheck %s -check-prefixes=CHECK,NOVLOPT
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs | \
+; RUN: FileCheck %s -check-prefixes=CHECK,NOVLOPT
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -riscv-enable-vl-optimizer \
+; RUN: -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -riscv-enable-vl-optimizer \
+; RUN: -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
@@ -17,7 +23,7 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
ret <vscale x 4 x i32> %w
}
-; No benificial to propagate VL since VL is larger in the use side.
+; Not beneficial to propagate VL since VL is larger in the use side.
define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
; CHECK: # %bb.0:
@@ -32,20 +38,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
}
define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_reg_vl_with_ta:
-; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v12
-; CHECK-NEXT: vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v8, v10
-; CHECK-NEXT: ret
+; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
+; NOVLOPT: # %bb.0:
+; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v10, v12
+; NOVLOPT-NEXT: vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v8, v10
+; NOVLOPT-NEXT: ret
+;
+; VLOPT-LABEL: different_imm_reg_vl_with_ta:
+; VLOPT: # %bb.0:
+; VLOPT-NEXT: vsetvli zero, a0, e32, m2, ta, ma
+; VLOPT-NEXT: vadd.vv v8, v10, v12
+; VLOPT-NEXT: vadd.vv v8, v8, v10
+; VLOPT-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
ret <vscale x 4 x i32> %w
}
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: different_imm_vl_with_ta_1:
; CHECK: # %bb.0:
@@ -63,13 +75,20 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
; undefined value.
define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_vl_with_ta:
-; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v10, v8, v10
-; CHECK-NEXT: vsetvli zero, a1, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v8
-; CHECK-NEXT: ret
+; NOVLOPT-LABEL: different_vl_with_ta:
+; NOVLOPT: # %bb.0:
+; NOVLOPT-NEXT: vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v10, v8, v10
+; NOVLOPT-NEXT: vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v10, v8
+; NOVLOPT-NEXT: ret
+;
+; VLOPT-LABEL: different_vl_with_ta:
+; VLOPT: # %bb.0:
+; VLOPT-NEXT: vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT: vadd.vv v10, v8, v10
+; VLOPT-NEXT: vadd.vv v8, v10, v8
+; VLOPT-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
ret <vscale x 4 x i32> %w
``````````
</details>
https://github.com/llvm/llvm-project/pull/112228
More information about the llvm-commits
mailing list