[llvm] ae68d53 - [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (#112228)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 16 11:58:04 PDT 2024
Author: Michael Maitland
Date: 2024-10-16T14:58:00-04:00
New Revision: ae68d532f810e217c747b10b26aeea3bb84c3844
URL: https://github.com/llvm/llvm-project/commit/ae68d532f810e217c747b10b26aeea3bb84c3844
DIFF: https://github.com/llvm/llvm-project/commit/ae68d532f810e217c747b10b26aeea3bb84c3844.diff
LOG: [RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (#112228)
The original goal of this pass was to focus on vector operations with
VLMAX. However, users often utilize only part of the result, and such
usage may come from the vectorizer.
We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector
operation sequences with different VLs.
---------
Co-authored-by: Kito Cheng <kito.cheng at sifive.com>
Added:
Modified:
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.h
llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index b8539a5d1add14..3989a966edfd33 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
assert(Scaled >= 3 && Scaled <= 6);
return Scaled;
}
+
+/// Given two VL operands, do we know that LHS <= RHS?
+bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
+ if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
+ LHS.getReg() == RHS.getReg())
+ return true;
+ if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
+ return true;
+ if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
+ return false;
+ if (!LHS.isImm() || !RHS.isImm())
+ return false;
+ return LHS.getImm() <= RHS.getImm();
+}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 457db9b9860d00..c3aa367486627a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW);
// Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
static constexpr int64_t VLMaxSentinel = -1LL;
+/// Given two VL operands, do we know that LHS <= RHS?
+bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS);
+
// Mask assignments for floating-point
static constexpr unsigned FPMASK_Negative_Infinity = 0x001;
static constexpr unsigned FPMASK_Negative_Normal = 0x002;
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 6053899987db9b..ee494c46815112 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -51,7 +51,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
StringRef getPassName() const override { return PASS_NAME; }
private:
- bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+ bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
bool tryReduceVL(MachineInstr &MI);
bool isCandidate(const MachineInstr &MI) const;
};
@@ -658,10 +658,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
if (MI.getNumDefs() != 1)
return false;
+ // If we're not using VLMAX, then we need to be careful whether we are using
+ // TA/TU when there is a non-undef passthru. But when we are using VLMAX, it
+ // does not matter whether we are using TA/TU with a non-undef passthru, since
+ // there are no tail elements to be preserved.
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = MI.getOperand(VLOpNum);
- if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+ if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) {
+ // If MI has a non-undef passthru, we will not try to optimize it since
+ // that requires us to preserve tail elements according to TA/TU.
+ // Otherwise, MI has an undef passthru, so it doesn't matter whether we
+ // are using TA/TU.
+ bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+ unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+ if (HasPassthru &&
+ MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) {
+ LLVM_DEBUG(
+ dbgs() << " Not a candidate because it uses non-undef passthru"
+ " with non-VLMAX VL\n");
+ return false;
+ }
+ }
+
+ // If the VL is 1, then there is no need to reduce it. This is an
+ // optimization, not needed to preserve correctness.
+ if (VLOp.isImm() && VLOp.getImm() == 1) {
+ LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
return false;
+ }
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
@@ -684,7 +708,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return true;
}
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
@@ -730,16 +754,17 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
- // Looking for a register VL that isn't X0.
- if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) {
- LLVM_DEBUG(dbgs() << " Abort due to user uses X0 as VL.\n");
- CanReduceVL = false;
- break;
- }
+
+ // Looking for an immediate or a register VL that isn't X0.
+ assert(!VLOp.isReg() ||
+ VLOp.getReg() != RISCV::X0 && "Did not expect X0 VL");
if (!CommonVL) {
- CommonVL = VLOp.getReg();
- } else if (*CommonVL != VLOp.getReg()) {
+ CommonVL = &VLOp;
+ LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
+ } else if (!CommonVL->isIdenticalTo(VLOp)) {
+ // FIXME: This check requires all users to have the same VL. We can relax
+ // this and get the largest VL amongst all users.
LLVM_DEBUG(dbgs() << " Abort because users have
diff erent VL\n");
CanReduceVL = false;
break;
@@ -776,7 +801,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
MachineInstr &MI = *Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
- std::optional<Register> CommonVL;
+ const MachineOperand *CommonVL = nullptr;
bool CanReduceVL = true;
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
CanReduceVL = checkUsers(CommonVL, MI);
@@ -784,21 +809,34 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
if (!CanReduceVL || !CommonVL)
continue;
- if (!CommonVL->isVirtual()) {
- LLVM_DEBUG(
- dbgs() << " Abort due to new VL is not virtual register.\n");
+ assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
+ "Expected VL to be an Imm or virtual Reg");
+
+ unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
+ MachineOperand &VLOp = MI.getOperand(VLOpNum);
+
+ if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
+ LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
continue;
}
- const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
- if (!MDT->dominates(VLMI, &MI))
- continue;
+ if (CommonVL->isImm()) {
+ LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
+ << CommonVL->getImm() << " for " << MI << "\n");
+ VLOp.ChangeToImmediate(CommonVL->getImm());
+ } else {
+ const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
+ if (!MDT->dominates(VLMI, &MI))
+ continue;
+ LLVM_DEBUG(
+ dbgs() << " Reduce VL from " << VLOp << " to "
+ << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
+ << " for " << MI << "\n");
+
+ // All our checks passed. We can reduce VL.
+ VLOp.ChangeToRegister(CommonVL->getReg(), false);
+ }
- // All our checks passed. We can reduce VL.
- LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n");
- unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
- MachineOperand &VLOp = MI.getOperand(VLOpNum);
- VLOp.ChangeToRegister(*CommonVL, false);
MadeChange = true;
// Now add all inputs to this instruction to the worklist.
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b883c50beadc09..a57bc5a3007d03 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0;
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)
-/// Given two VL operands, do we know that LHS <= RHS?
-static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
- if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
- LHS.getReg() == RHS.getReg())
- return true;
- if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
- return true;
- if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
- return false;
- if (!LHS.isImm() || !RHS.isImm())
- return false;
- return LHS.getImm() <= RHS.getImm();
-}
-
/// Given \p User that has an input operand with EEW=SEW, which uses the dest
/// operand of \p Src with an unknown EEW, return true if their EEWs match.
bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
@@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
return false;
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
- if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
+ if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
return false;
if (!ensureDominates(VL, *Src))
@@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
MachineOperand &SrcPolicy =
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));
- if (isVLKnownLE(MIVL, SrcVL))
+ if (RISCV::isVLKnownLE(MIVL, SrcVL))
SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
}
@@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
// so we don't need to handle a smaller source VL here. However, the
// user's VL may be larger
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
- if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
+ if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
return false;
// If the new passthru doesn't dominate Src, try to move Src so it does.
@@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
// If MI was tail agnostic and the VL didn't increase, preserve it.
int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
- isVLKnownLE(MI.getOperand(3), SrcVL))
+ RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
Policy |= RISCVII::TAIL_AGNOSTIC;
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index 0b3e67ec895566..1a1472fcfc66f5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -11,19 +11,46 @@
declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
define <vscale x 4 x i32> @
diff erent_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL:
diff erent_imm_vl_with_ta:
-; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v12
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v8, v10
-; CHECK-NEXT: ret
+; NOVLOPT-LABEL:
diff erent_imm_vl_with_ta:
+; NOVLOPT: # %bb.0:
+; NOVLOPT-NEXT: vsetivli zero, 5, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v10, v12
+; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v8, v10
+; NOVLOPT-NEXT: ret
+;
+; VLOPT-LABEL:
diff erent_imm_vl_with_ta:
+; VLOPT: # %bb.0:
+; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; VLOPT-NEXT: vadd.vv v8, v10, v12
+; VLOPT-NEXT: vadd.vv v8, v8, v10
+; VLOPT-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5)
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
ret <vscale x 4 x i32> %w
}
-; No benificial to propagate VL since VL is larger in the use side.
+define <vscale x 4 x i32> @vlmax_and_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
+; NOVLOPT-LABEL: vlmax_and_imm_vl_with_ta:
+; NOVLOPT: # %bb.0:
+; NOVLOPT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v10, v12
+; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT: vadd.vv v8, v8, v10
+; NOVLOPT-NEXT: ret
+;
+; VLOPT-LABEL: vlmax_and_imm_vl_with_ta:
+; VLOPT: # %bb.0:
+; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; VLOPT-NEXT: vadd.vv v8, v10, v12
+; VLOPT-NEXT: vadd.vv v8, v8, v10
+; VLOPT-NEXT: ret
+ %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen -1)
+ %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
+ ret <vscale x 4 x i32> %w
+}
+
+; Not beneficial to propagate VL since VL is larger in the use side.
define <vscale x 4 x i32> @
diff erent_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL:
diff erent_imm_vl_with_ta_larger_vl:
; CHECK: # %bb.0:
@@ -50,8 +77,7 @@ define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %pass
ret <vscale x 4 x i32> %w
}
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
define <vscale x 4 x i32> @
diff erent_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL:
diff erent_imm_vl_with_ta_1:
; CHECK: # %bb.0:
@@ -110,7 +136,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
ret <vscale x 4 x i32> %w
}
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; NOVLOPT: {{.*}}
-; VLOPT: {{.*}}
More information about the llvm-commits
mailing list