[llvm] [RISCV] Move exact VLEN VLMAX transform to RISCVVectorPeephole (PR #100551)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 25 03:17:50 PDT 2024
https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/100551
We can teach RISCVVectorPeephole to detect when an AVL is equal to VLMAX (which is known whenever the exact VLEN is known) and replace it with the VLMAX sentinel. Doing so removes the need for getVLOp in RISCVISelLowering and keeps all the VLMAX logic in one place.
From 701fbc43dc67376782b0d6f5c614dc0dea4f3f4c Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 25 Jul 2024 18:08:17 +0800
Subject: [PATCH] [RISCV] Move exact VLEN VLMAX transform to RISCVVectorPeephole
We can teach RISCVVectorPeephole to detect when an AVL is equal to VLMAX (which is known whenever the exact VLEN is known) and replace it with the VLMAX sentinel. Doing so removes the need for getVLOp in RISCVISelLowering and keeps all the VLMAX logic in one place.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 35 +++----------
llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp | 52 ++++++++++++++-----
llvm/test/CodeGen/RISCV/rvv/pr83017.ll | 6 +--
llvm/test/CodeGen/RISCV/rvv/pr90559.ll | 6 +--
4 files changed, 52 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..0339b302fb218 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2758,19 +2758,6 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
}
-static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
- SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
- // If we know the exact VLEN, and our VL is exactly equal to VLMAX,
- // canonicalize the representation. InsertVSETVLI will pick the immediate
- // encoding later if profitable.
- const auto [MinVLMAX, MaxVLMAX] =
- RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
- if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
- return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
-
- return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
-}
-
static std::pair<SDValue, SDValue>
getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
@@ -2784,7 +2771,7 @@ static std::pair<SDValue, SDValue>
getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
- SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
+ SDValue VL = DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
return {Mask, VL};
}
@@ -9427,8 +9414,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
MVT VT = Op->getSimpleValueType(0);
MVT ContainerVT = getContainerForFixedLengthVector(VT);
- SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
- Subtarget);
+ SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
auto *Load = cast<MemIntrinsicSDNode>(Op);
SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
@@ -9507,8 +9493,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
MVT VT = Op->getOperand(2).getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(VT);
- SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
- Subtarget);
+ SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
SDValue Ptr = Op->getOperand(NF + 2);
@@ -9974,7 +9959,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
// Set the vector length to only the number of elements we care about. Note
// that for slideup this includes the offset.
unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
- SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
+ SDValue VL = DAG.getConstant(EndIndex, DL, XLenVT);
// Use tail agnostic policy if we're inserting over Vec's tail.
unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
@@ -10211,8 +10196,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
// Set the vector length to only the number of elements we care about. This
// avoids sliding down elements we're going to discard straight away.
- SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
- Subtarget);
+ SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
SDValue Slidedown =
getVSlidedown(DAG, Subtarget, DL, ContainerVT,
@@ -10287,8 +10271,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
SDValue SlidedownAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
if (SubVecVT.isFixedLengthVector())
- VL = getVLOp(SubVecVT.getVectorNumElements(), InterSubVT, DL, DAG,
- Subtarget);
+ VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
SDValue Slidedown =
getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
Vec, SlidedownAmt, Mask, VL);
@@ -10668,7 +10651,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
}
- SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
+ SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
SDValue IntID = DAG.getTargetConstant(
@@ -10715,7 +10698,6 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
SDValue NewValue =
convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
-
// If we know the exact VLEN and our fixed length vector completely fills
// the container, use a whole register store instead.
const auto [MinVLMAX, MaxVLMAX] =
@@ -10728,8 +10710,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
MMO->getFlags(), MMO->getAAInfo());
}
- SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
- Subtarget);
+ SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
SDValue IntID = DAG.getTargetConstant(
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b083e64cfc8d7..f328c55e1d3ba 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -47,6 +47,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
const TargetInstrInfo *TII;
MachineRegisterInfo *MRI;
const TargetRegisterInfo *TRI;
+ const RISCVSubtarget *ST;
RISCVVectorPeephole() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF) override;
@@ -64,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
bool convertVMergeToVMv(MachineInstr &MI) const;
bool isAllOnesMask(const MachineInstr *MaskDef) const;
+ std::optional<unsigned> getConstant(const MachineOperand &VL) const;
/// Maps uses of V0 to the corresponding def of V0.
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
@@ -76,13 +78,44 @@ char RISCVVectorPeephole::ID = 0;
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)
-// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
-// to the VLMAX sentinel value.
+/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
+std::optional<unsigned>
+RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
+  if (VL.isImm())
+    return VL.getImm();
+  // An ADDI's first source may be a frame index (ADDI %stack.N, 0), not a reg.
+  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
+  if (!Def || Def->getOpcode() != RISCV::ADDI || !Def->getOperand(1).isReg() ||
+      Def->getOperand(1).getReg() != RISCV::X0 || !Def->getOperand(2).isImm())
+    return std::nullopt;
+  return Def->getOperand(2).getImm();
+}
+
+/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
!RISCVII::hasSEWOp(MI.getDesc().TSFlags))
return false;
+
+ auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
+ // Fixed-point value, denominator=8
+ unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
+ unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+ // A Log2SEW of 0 is an operation on mask registers only
+ unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
+ assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
+ assert(8 * LMULFixed / SEW > 0);
+
+ // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
+ if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
+ VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
+ VL.ChangeToImmediate(RISCV::VLMaxSentinel);
+ return true;
+ }
+
+ // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
+ // it to the VLMAX sentinel value.
if (!VL.isReg())
return false;
MachineInstr *Def = MRI->getVRegDef(VL.getReg());
@@ -105,15 +138,6 @@ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
return false;
- auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
- // Fixed-point value, denominator=8
- unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
- unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
- // A Log2SEW of 0 is an operation on mask registers only
- unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
- assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
- assert(8 * LMULFixed / SEW > 0);
-
// AVL = (VLENB * Scale)
//
// VLMAX = (VLENB * 8 * LMUL) / SEW
@@ -302,11 +326,11 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
return false;
// Skip if the vector extension is not enabled.
- const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
- if (!ST.hasVInstructions())
+ ST = &MF.getSubtarget<RISCVSubtarget>();
+ if (!ST->hasVInstructions())
return false;
- TII = ST.getInstrInfo();
+ TII = ST->getInstrInfo();
MRI = &MF.getRegInfo();
TRI = MRI->getTargetRegisterInfo();
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
index 3719a2ad994d6..beca480378a35 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
@@ -35,11 +35,11 @@ define void @aliasing(ptr %p) {
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs1r.v v8, (a2)
-; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT: vmv.v.i v12, 0
-; CHECK-NEXT: vs4r.v v12, (a0)
; CHECK-NEXT: addi a2, a0, 64
; CHECK-NEXT: vs1r.v v8, (a2)
+; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vs4r.v v8, (a0)
; CHECK-NEXT: sw a1, 84(a0)
; CHECK-NEXT: ret
%q = getelementptr inbounds i8, ptr %p, i64 84
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
index 8d330b12055ae..7e109f307c4a5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
@@ -32,11 +32,11 @@ define void @f(ptr %p) vscale_range(2,2) {
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs1r.v v8, (a2)
-; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT: vmv.v.i v12, 0
-; CHECK-NEXT: vs4r.v v12, (a0)
; CHECK-NEXT: addi a2, a0, 64
; CHECK-NEXT: vs1r.v v8, (a2)
+; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vs4r.v v8, (a0)
; CHECK-NEXT: sw a1, 84(a0)
; CHECK-NEXT: ret
%q = getelementptr inbounds i8, ptr %p, i64 84
More information about the llvm-commits mailing list