[llvm] [RISCV] Unify vsetvli compatibility logic in forward and backwards passes (PR #71657)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 8 03:23:01 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
The backwards local postpass has its own logic for figuring out whether two vsetvlis
are compatible, separate from the isCompatible used by the forward pass. However,
both largely compute the same thing, i.e. whether it is possible to mutate vsetvli A
into vsetvli B given the demanded properties.
The main difference between the two is that the backwards postpass needs to be able to reason about `vsetvli x0, x0`, whereas the forward pass doesn't need to (because regular vector pseudos can't express it).
So if we teach the VSETVLIInfo used by the forward pass to handle the vsetvli x0,
x0 case, then it becomes possible to unify the two passes. To do this we
introduce a new state to represent that the VL is preserved from the previous
vsetvli. Then in VSETVLIInfo::isCompatible, we can use this information to skip
checking whether the AVLs are the same when we know that the second vsetvli is x0, x0.
For the backwards pass, we keep a running VSETVLIInfo as we iterate up
through the basic block, and replace canMutatePriorConfig with
VSETVLIInfo::isCompatible.
It's possible now to move areCompatibleVTYPEs into VSETVLIInfo, but I've
deferred that code motion for now to keep the diff small, and can move it
afterwards as an NFC.
We now need to represent the notion that a VSETVLIInfo's state
---
Full diff: https://github.com/llvm/llvm-project/pull/71657.diff
4 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp (+49-60)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll (+4-7)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll (+8-21)
- (modified) llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll (+2-2)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
index f6d8b1f0a70e13d..f328cef92a95b9a 100644
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -419,6 +419,17 @@ DemandedFields getDemanded(const MachineInstr &MI,
return Res;
}
+static MachineInstr *isADDIX0(Register Reg, const MachineRegisterInfo &MRI) {
+ if (Reg == RISCV::X0)
+ return nullptr;
+ if (MachineInstr *MI = MRI.getVRegDef(Reg);
+ MI && MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
+ MI->getOperand(2).isImm() && MI->getOperand(1).getReg() == RISCV::X0 &&
+ MI->getOperand(2).getImm() != 0)
+ return MI;
+ return nullptr;
+}
+
/// Defines the abstract state with which the forward dataflow models the
/// values of the VL and VTYPE registers after insertion.
class VSETVLIInfo {
@@ -431,6 +442,7 @@ class VSETVLIInfo {
Uninitialized,
AVLIsReg,
AVLIsImm,
+ PreserveVL, // vsetvli x0, x0
Unknown,
} State = Uninitialized;
@@ -466,6 +478,8 @@ class VSETVLIInfo {
State = AVLIsImm;
}
+ void setPreserveVL() { State = PreserveVL; }
+
bool hasAVLImm() const { return State == AVLIsImm; }
bool hasAVLReg() const { return State == AVLIsReg; }
Register getAVLReg() const {
@@ -486,11 +500,7 @@ class VSETVLIInfo {
if (hasAVLReg()) {
if (getAVLReg() == RISCV::X0)
return true;
- if (MachineInstr *MI = MRI.getVRegDef(getAVLReg());
- MI && MI->getOpcode() == RISCV::ADDI &&
- MI->getOperand(1).isReg() && MI->getOperand(2).isImm() &&
- MI->getOperand(1).getReg() == RISCV::X0 &&
- MI->getOperand(2).getImm() != 0)
+ if (isADDIX0(getAVLReg(), MRI))
return true;
return false;
}
@@ -579,8 +589,11 @@ class VSETVLIInfo {
// Determine whether the vector instructions requirements represented by
// Require are compatible with the previous vsetvli instruction represented
// by this. MI is the instruction whose requirements we're considering.
+ // The instruction represented by Require should come after this, unless
+ // OrderReversed is true.
bool isCompatible(const DemandedFields &Used, const VSETVLIInfo &Require,
- const MachineRegisterInfo &MRI) const {
+ const MachineRegisterInfo &MRI,
+ bool OrderReversed = false) const {
assert(isValid() && Require.isValid() &&
"Can't compare invalid VSETVLIInfos");
assert(!Require.SEWLMULRatioOnly &&
@@ -593,11 +606,15 @@ class VSETVLIInfo {
if (SEWLMULRatioOnly)
return false;
- if (Used.VLAny && !hasSameAVL(Require))
- return false;
+ // If the VL will be preserved, then we don't need to check the AVL.
+ const uint8_t EndState = OrderReversed ? State : Require.State;
+ if (EndState != PreserveVL) {
+ if (Used.VLAny && !hasSameAVL(Require))
+ return false;
- if (Used.VLZeroness && !hasEquallyZeroAVL(Require, MRI))
- return false;
+ if (Used.VLZeroness && !hasEquallyZeroAVL(Require, MRI))
+ return false;
+ }
return hasCompatibleVTYPE(Used, Require);
}
@@ -849,9 +866,11 @@ static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI) {
assert(MI.getOpcode() == RISCV::PseudoVSETVLI ||
MI.getOpcode() == RISCV::PseudoVSETVLIX0);
Register AVLReg = MI.getOperand(1).getReg();
- assert((AVLReg != RISCV::X0 || MI.getOperand(0).getReg() != RISCV::X0) &&
- "Can't handle X0, X0 vsetvli yet");
- NewInfo.setAVLReg(AVLReg);
+
+ if (AVLReg == RISCV::X0 && MI.getOperand(0).getReg() == RISCV::X0)
+ NewInfo.setPreserveVL();
+ else
+ NewInfo.setAVLReg(AVLReg);
}
NewInfo.setVTYPE(MI.getOperand(2).getImm());
@@ -1426,52 +1445,9 @@ static void doUnion(DemandedFields &A, DemandedFields B) {
A.MaskPolicy |= B.MaskPolicy;
}
-static bool isNonZeroAVL(const MachineOperand &MO) {
- if (MO.isReg())
- return RISCV::X0 == MO.getReg();
- assert(MO.isImm());
- return 0 != MO.getImm();
-}
-
-// Return true if we can mutate PrevMI to match MI without changing any the
-// fields which would be observed.
-static bool canMutatePriorConfig(const MachineInstr &PrevMI,
- const MachineInstr &MI,
- const DemandedFields &Used) {
- // If the VL values aren't equal, return false if either a) the former is
- // demanded, or b) we can't rewrite the former to be the later for
- // implementation reasons.
- if (!isVLPreservingConfig(MI)) {
- if (Used.VLAny)
- return false;
-
- // We don't bother to handle the equally zero case here as it's largely
- // uninteresting.
- if (Used.VLZeroness) {
- if (isVLPreservingConfig(PrevMI))
- return false;
- if (!isNonZeroAVL(MI.getOperand(1)) ||
- !isNonZeroAVL(PrevMI.getOperand(1)))
- return false;
- }
-
- // TODO: Track whether the register is defined between
- // PrevMI and MI.
- if (MI.getOperand(1).isReg() &&
- RISCV::X0 != MI.getOperand(1).getReg())
- return false;
- }
-
- if (!PrevMI.getOperand(2).isImm() || !MI.getOperand(2).isImm())
- return false;
-
- auto PriorVType = PrevMI.getOperand(2).getImm();
- auto VType = MI.getOperand(2).getImm();
- return areCompatibleVTYPEs(PriorVType, VType, Used);
-}
-
void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
MachineInstr *NextMI = nullptr;
+ VSETVLIInfo NextInfo;
// We can have arbitrary code in successors, so VL and VTYPE
// must be considered demanded.
DemandedFields Used;
@@ -1482,6 +1458,7 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
if (!isVectorConfigInstr(MI)) {
doUnion(Used, getDemanded(MI, MRI, ST));
+ transferAfter(NextInfo, MI);
continue;
}
@@ -1495,14 +1472,25 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
ToDelete.push_back(&MI);
// Leave NextMI unchanged
continue;
- } else if (canMutatePriorConfig(MI, *NextMI, Used)) {
+ } else if (NextInfo.isCompatible(Used, getInfoForVSETVLI(MI), *MRI,
+ true)) {
if (!isVLPreservingConfig(*NextMI)) {
MI.getOperand(0).setReg(NextMI->getOperand(0).getReg());
MI.getOperand(0).setIsDead(false);
if (NextMI->getOperand(1).isImm())
MI.getOperand(1).ChangeToImmediate(NextMI->getOperand(1).getImm());
- else
- MI.getOperand(1).ChangeToRegister(NextMI->getOperand(1).getReg(), false);
+ else {
+ // NextMI may have an AVL (addi x0, imm) whilst MI might have a
+ // different non-zero AVL. But the AVLs may be considered
+ // compatible. So hoist it up to MI in case it's not already
+ // dominated by it. See hasNonZeroAVL.
+ if (MachineInstr *ADDI =
+ isADDIX0(NextMI->getOperand(1).getReg(), *MRI))
+ ADDI->moveBefore(&MI);
+
+ MI.getOperand(1).ChangeToRegister(NextMI->getOperand(1).getReg(),
+ false);
+ }
MI.setDesc(NextMI->getDesc());
}
MI.getOperand(2).setImm(NextMI->getOperand(2).getImm());
@@ -1511,6 +1499,7 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
}
}
NextMI = &MI;
+ NextInfo = getInfoForVSETVLI(MI);
Used = getDemanded(MI, MRI, ST);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
index 3cc7371c1ce9ac4..38efe0d30ed9511 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
@@ -51,11 +51,10 @@ define <32 x i32> @insertelt_v32i32_0(<32 x i32> %a, i32 %y) {
define <32 x i32> @insertelt_v32i32_4(<32 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v32i32_4:
; CHECK: # %bb.0:
-; CHECK-NEXT: li a1, 32
-; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
-; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vsetivli zero, 5, e32, m2, tu, ma
+; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vslideup.vi v8, v16, 4
+; CHECK-NEXT: li a0, 32
; CHECK-NEXT: ret
%b = insertelement <32 x i32> %a, i32 %y, i32 4
ret <32 x i32> %b
@@ -65,9 +64,8 @@ define <32 x i32> @insertelt_v32i32_31(<32 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v32i32_31:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
-; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
-; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vslideup.vi v8, v16, 31
; CHECK-NEXT: ret
%b = insertelement <32 x i32> %a, i32 %y, i32 31
@@ -103,9 +101,8 @@ define <64 x i32> @insertelt_v64i32_63(<64 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v64i32_63:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
-; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
-; CHECK-NEXT: vmv.s.x v24, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v24, a0
; CHECK-NEXT: vslideup.vi v16, v24, 31
; CHECK-NEXT: ret
%b = insertelement <64 x i32> %a, i32 %y, i32 63
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index 728cf18e1a77d8a..2ab868835552ea3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -12418,12 +12418,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 1
; RV64ZVE32F-NEXT: beqz a2, .LBB98_2
; RV64ZVE32F-NEXT: # %bb.1: # %cond.load
-; RV64ZVE32F-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: li a2, 32
+; RV64ZVE32F-NEXT: vsetvli zero, a2, e8, mf4, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
-; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, tu, ma
; RV64ZVE32F-NEXT: vmv.s.x v10, a2
; RV64ZVE32F-NEXT: .LBB98_2: # %else
; RV64ZVE32F-NEXT: andi a2, a1, 2
@@ -12452,14 +12451,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 16
; RV64ZVE32F-NEXT: beqz a2, .LBB98_8
; RV64ZVE32F-NEXT: .LBB98_7: # %cond.load10
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 5, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v13
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 5, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 4
; RV64ZVE32F-NEXT: .LBB98_8: # %else11
; RV64ZVE32F-NEXT: vsetivli zero, 8, e8, m1, ta, ma
@@ -12592,14 +12589,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 43
; RV64ZVE32F-NEXT: bgez a2, .LBB98_32
; RV64ZVE32F-NEXT: .LBB98_31: # %cond.load58
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 21, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v9
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 21, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 20
; RV64ZVE32F-NEXT: .LBB98_32: # %else59
; RV64ZVE32F-NEXT: vsetivli zero, 8, e8, m1, ta, ma
@@ -12742,14 +12737,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 256
; RV64ZVE32F-NEXT: beqz a2, .LBB98_13
; RV64ZVE32F-NEXT: .LBB98_53: # %cond.load22
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 9, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v12
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v13, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 9, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v13, 8
; RV64ZVE32F-NEXT: andi a2, a1, 512
; RV64ZVE32F-NEXT: bnez a2, .LBB98_14
@@ -12777,14 +12770,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 47
; RV64ZVE32F-NEXT: bgez a2, .LBB98_26
; RV64ZVE32F-NEXT: .LBB98_56: # %cond.load46
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 17, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 17, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 16
; RV64ZVE32F-NEXT: slli a2, a1, 46
; RV64ZVE32F-NEXT: bltz a2, .LBB98_27
@@ -12835,14 +12826,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 39
; RV64ZVE32F-NEXT: bgez a2, .LBB98_37
; RV64ZVE32F-NEXT: .LBB98_61: # %cond.load70
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 25, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 25, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 24
; RV64ZVE32F-NEXT: slli a2, a1, 38
; RV64ZVE32F-NEXT: bltz a2, .LBB98_38
@@ -12870,14 +12859,12 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 35
; RV64ZVE32F-NEXT: bgez a2, .LBB98_42
; RV64ZVE32F-NEXT: .LBB98_64: # %cond.load82
-; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
+; RV64ZVE32F-NEXT: vsetivli zero, 29, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v9
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
-; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
-; RV64ZVE32F-NEXT: vsetivli zero, 29, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 28
; RV64ZVE32F-NEXT: slli a2, a1, 34
; RV64ZVE32F-NEXT: bltz a2, .LBB98_43
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
index 4d0f640408dd2a9..5ddb5469c079ec5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
@@ -329,9 +329,9 @@ entry:
define double @test17(i64 %avl, <vscale x 1 x double> %a, <vscale x 1 x double> %b) nounwind {
; CHECK-LABEL: test17:
; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, a0, e64, m1, ta, ma
-; CHECK-NEXT: vfmv.f.s fa5, v8
+; CHECK-NEXT: vsetvli a0, a0, e32, mf2, ta, ma
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vfadd.vv v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.d fa0, fa5, fa4
``````````
</details>
https://github.com/llvm/llvm-project/pull/71657
More information about the llvm-commits
mailing list