[llvm] [RISCV] Convert AVLs with vlenb to VLMAX where possible (PR #97800)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 5 02:18:22 PDT 2024
https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/97800
>From 8f2d8b70c5328c8418db8ead7b7b7a2d1422b632 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 5 Jul 2024 17:01:24 +0800
Subject: [PATCH 1/2] [RISCV] Convert AVLs with vlenb to VLMAX where possible
Given an AVL that's computed from vlenb, if it's equal to VLMAX then we can replace it with the VLMAX sentinel value.
The main motivation is to be able to express an EVL of VLMAX in VP intrinsics whilst emitting vsetvli a0, zero, so that we can replace llvm.riscv.masked.strided.{load,store} with their VP counterparts.
This is done in RISCVFoldMasks instead of SelectionDAG since VP nodes are lowered in multiple places there, all of which would have needed to be handled.
This also avoids doing it in RISCVInsertVSETVLI, where it's much harder to look up the value of the AVL; in RISCVFoldMasks we can additionally take advantage of DeadMachineInstrElim to remove any leftover PseudoReadVLENBs.
---
llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 56 +++++++++++++++++++
.../CodeGen/RISCV/rvv/insert-subvector.ll | 12 ++--
llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll | 13 ++---
llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll | 13 ++---
llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll | 13 ++---
llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll | 13 ++---
llvm/test/CodeGen/RISCV/rvv/vminu-vp.ll | 13 ++---
7 files changed, 87 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index 2089f5dda6fe5..f65127beaaa2b 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -14,6 +14,12 @@
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
+// It also converts AVLs that are provably equal to VLMAX to the VLMAX sentinel
+// %vl = VLENB, optionally shifted by slli/srli, with %vl == VLMAX
+// PseudoVADD_V_V %a, %b, %vl
+// ->
+// PseudoVADD_V_V %a, %b, -1
+//
//===---------------------------------------------------------------------===//
#include "RISCV.h"
@@ -47,6 +53,7 @@ class RISCVFoldMasks : public MachineFunctionPass {
StringRef getPassName() const override { return "RISC-V Fold Masks"; }
private:
+ bool convertToVLMAX(MachineInstr &MI) const;
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;
@@ -62,6 +69,54 @@ char RISCVFoldMasks::ID = 0;
INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
+// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
+// to the VLMAX sentinel value.
+//
+// The AVL is matched as PseudoReadVLENB, optionally shifted by a single
+// SLLI/SRLI immediate, i.e. AVL == VLENB * 2^Log2Scale. The comparison with
+// VLMAX is done on power-of-two exponents so that it is exact. A large shift
+// amount can neither overflow an integer (an RV64 SLLI immediate may be up to
+// 63) nor be truncated to a value that spuriously matches.
+//
+// Returns true if the VL operand of MI was rewritten to RISCV::VLMaxSentinel.
+bool RISCVFoldMasks::convertToVLMAX(MachineInstr &MI) const {
+  const MCInstrDesc &Desc = MI.getDesc();
+  if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
+    return false;
+  MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(Desc));
+  // getVRegDef assumes a single (SSA) definition, so only look through vregs.
+  if (!VL.isReg() || !VL.getReg().isVirtual())
+    return false;
+  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
+  if (!Def)
+    return false;
+
+  // Check if the VLENB was potentially scaled with slli/srli.
+  // AVL == VLENB * 2^Log2Scale.
+  int64_t Log2Scale = 0;
+  if (Def->getOpcode() == RISCV::SLLI || Def->getOpcode() == RISCV::SRLI) {
+    Register Src = Def->getOperand(1).getReg();
+    if (!Src.isVirtual())
+      return false;
+    int64_t ShAmt = Def->getOperand(2).getImm();
+    Log2Scale = Def->getOpcode() == RISCV::SLLI ? ShAmt : -ShAmt;
+    Def = MRI->getVRegDef(Src);
+  }
+
+  if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
+    return false;
+
+  auto [LMul, Fractional] =
+      RISCVVType::decodeVLMUL(RISCVII::getLMul(Desc.TSFlags));
+  int64_t Log2LMul = Log2_32(LMul);
+  if (Fractional)
+    Log2LMul = -Log2LMul;
+
+  int64_t Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+  // A Log2SEW of 0 denotes an operation on mask registers only, which is
+  // configured with e8; treating it as SEW=1 would compute the wrong VLMAX.
+  if (Log2SEW == 0)
+    Log2SEW = 3;
+
+  // AVL = VLENB * 2^Log2Scale
+  //
+  // VLMAX = (VLENB * 8 * LMUL) / SEW
+  //
+  // AVL == VLMAX
+  // -> 2^Log2Scale == (8 * LMUL) / SEW
+  // -> Log2Scale == 3 + log2(LMUL) - log2(SEW)
+  if (Log2Scale != 3 + Log2LMul - Log2SEW)
+    return false;
+
+  VL.ChangeToImmediate(RISCV::VLMaxSentinel);
+
+  return true;
+}
+
bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
assert(MaskDef && MaskDef->isCopy() &&
MaskDef->getOperand(0).getReg() == RISCV::V0);
@@ -213,6 +268,7 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
+ Changed |= convertToVLMAX(MI);
Changed |= convertToUnmasked(MI);
Changed |= convertVMergeToVMv(MI);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/insert-subvector.ll b/llvm/test/CodeGen/RISCV/rvv/insert-subvector.ll
index 0cd4f423a9df6..8d917f286720a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/insert-subvector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/insert-subvector.ll
@@ -305,9 +305,9 @@ define <vscale x 16 x i8> @insert_nxv16i8_nxv1i8_7(<vscale x 16 x i8> %vec, <vsc
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a1, a0, 3
-; CHECK-NEXT: sub a1, a0, a1
-; CHECK-NEXT: vsetvli zero, a0, e8, m1, ta, ma
-; CHECK-NEXT: vslideup.vx v8, v10, a1
+; CHECK-NEXT: sub a0, a0, a1
+; CHECK-NEXT: vsetvli a1, zero, e8, m1, ta, ma
+; CHECK-NEXT: vslideup.vx v8, v10, a0
; CHECK-NEXT: ret
%v = call <vscale x 16 x i8> @llvm.vector.insert.nxv1i8.nxv16i8(<vscale x 16 x i8> %vec, <vscale x 1 x i8> %subvec, i64 7)
ret <vscale x 16 x i8> %v
@@ -318,9 +318,9 @@ define <vscale x 16 x i8> @insert_nxv16i8_nxv1i8_15(<vscale x 16 x i8> %vec, <vs
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a1, a0, 3
-; CHECK-NEXT: sub a1, a0, a1
-; CHECK-NEXT: vsetvli zero, a0, e8, m1, ta, ma
-; CHECK-NEXT: vslideup.vx v9, v10, a1
+; CHECK-NEXT: sub a0, a0, a1
+; CHECK-NEXT: vsetvli a1, zero, e8, m1, ta, ma
+; CHECK-NEXT: vslideup.vx v9, v10, a0
; CHECK-NEXT: ret
%v = call <vscale x 16 x i8> @llvm.vector.insert.nxv1i8.nxv16i8(<vscale x 16 x i8> %vec, <vscale x 1 x i8> %subvec, i64 15)
ret <vscale x 16 x i8> %v
diff --git a/llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll
index ede395f4df8e1..2a4fbb248cd9c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll
@@ -1436,20 +1436,17 @@ define <vscale x 32 x i32> @vadd_vi_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, <v
define <vscale x 32 x i32> @vadd_vi_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, <vscale x 32 x i1> %m) {
; RV32-LABEL: vadd_vi_nxv32i32_evl_nx16:
; RV32: # %bb.0:
-; RV32-NEXT: csrr a0, vlenb
-; RV32-NEXT: slli a0, a0, 1
-; RV32-NEXT: vsetvli zero, a0, e32, m8, ta, ma
+; RV32-NEXT: vsetvli a0, zero, e32, m8, ta, ma
; RV32-NEXT: vadd.vi v8, v8, -1, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vadd_vi_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a0, vlenb
-; RV64-NEXT: srli a1, a0, 2
-; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
-; RV64-NEXT: vslidedown.vx v24, v0, a1
-; RV64-NEXT: slli a0, a0, 1
-; RV64-NEXT: vsetvli zero, a0, e32, m8, ta, ma
+; RV64-NEXT: srli a0, a0, 2
+; RV64-NEXT: vsetvli a1, zero, e8, mf2, ta, ma
+; RV64-NEXT: vslidedown.vx v24, v0, a0
+; RV64-NEXT: vsetvli a0, zero, e32, m8, ta, ma
; RV64-NEXT: vadd.vi v8, v8, -1, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
diff --git a/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
index c15caa31bb098..5fdfb332da7cf 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
@@ -1073,20 +1073,17 @@ define <vscale x 32 x i32> @vmax_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i3
define <vscale x 32 x i32> @vmax_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vmax_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
-; RV32-NEXT: csrr a1, vlenb
-; RV32-NEXT: slli a1, a1, 1
-; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vmax.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vmax_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
-; RV64-NEXT: srli a2, a1, 2
-; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
-; RV64-NEXT: vslidedown.vx v24, v0, a2
-; RV64-NEXT: slli a1, a1, 1
-; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV64-NEXT: srli a1, a1, 2
+; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
+; RV64-NEXT: vslidedown.vx v24, v0, a1
+; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vmax.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
diff --git a/llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll
index df494f8af7387..7d678950b7a3c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll
@@ -1072,20 +1072,17 @@ define <vscale x 32 x i32> @vmaxu_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i
define <vscale x 32 x i32> @vmaxu_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vmaxu_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
-; RV32-NEXT: csrr a1, vlenb
-; RV32-NEXT: slli a1, a1, 1
-; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vmaxu.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vmaxu_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
-; RV64-NEXT: srli a2, a1, 2
-; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
-; RV64-NEXT: vslidedown.vx v24, v0, a2
-; RV64-NEXT: slli a1, a1, 1
-; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV64-NEXT: srli a1, a1, 2
+; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
+; RV64-NEXT: vslidedown.vx v24, v0, a1
+; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vmaxu.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
diff --git a/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
index 794a21c7c6aba..98a288ed68b9a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
@@ -1073,20 +1073,17 @@ define <vscale x 32 x i32> @vmin_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i3
define <vscale x 32 x i32> @vmin_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vmin_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
-; RV32-NEXT: csrr a1, vlenb
-; RV32-NEXT: slli a1, a1, 1
-; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vmin.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vmin_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
-; RV64-NEXT: srli a2, a1, 2
-; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
-; RV64-NEXT: vslidedown.vx v24, v0, a2
-; RV64-NEXT: slli a1, a1, 1
-; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV64-NEXT: srli a1, a1, 2
+; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
+; RV64-NEXT: vslidedown.vx v24, v0, a1
+; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vmin.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
diff --git a/llvm/test/CodeGen/RISCV/rvv/vminu-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vminu-vp.ll
index d54de281a7fd2..34b554b7ff514 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vminu-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vminu-vp.ll
@@ -1072,20 +1072,17 @@ define <vscale x 32 x i32> @vminu_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i
define <vscale x 32 x i32> @vminu_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vminu_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
-; RV32-NEXT: csrr a1, vlenb
-; RV32-NEXT: slli a1, a1, 1
-; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vminu.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vminu_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
-; RV64-NEXT: srli a2, a1, 2
-; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
-; RV64-NEXT: vslidedown.vx v24, v0, a2
-; RV64-NEXT: slli a1, a1, 1
-; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV64-NEXT: srli a1, a1, 2
+; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
+; RV64-NEXT: vslidedown.vx v24, v0, a1
+; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vminu.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
>From 06cecabdc93dbef20ec74f44dc09e7e8224aa08d Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 5 Jul 2024 17:18:06 +0800
Subject: [PATCH 2/2] Fix typos
---
llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index f65127beaaa2b..88ee220e01d96 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -82,9 +82,9 @@ bool RISCVFoldMasks::convertToVLMAX(MachineInstr &MI) const {
if (!Def)
return false;
- // Fixed-point value, denumerator=8
+ // Fixed-point value, denominator=8
unsigned ScaleFixed = 8;
- // Check if the VLENB was scaled for a possible slli/srli
+ // Check if the VLENB was potentially scaled with slli/srli
if (Def->getOpcode() == RISCV::SLLI) {
ScaleFixed <<= Def->getOperand(2).getImm();
Def = MRI->getVRegDef(Def->getOperand(1).getReg());
@@ -97,7 +97,7 @@ bool RISCVFoldMasks::convertToVLMAX(MachineInstr &MI) const {
return false;
auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
- // Fixed-point value, denumerator=8
+ // Fixed-point value, denominator=8
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
unsigned SEW =
1 << MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
More information about the llvm-commits
mailing list