[llvm] [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (PR #88307)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 10 11:33:49 PDT 2024
https://github.com/mshockwave created https://github.com/llvm/llvm-project/pull/88307
This patch covers VADD_VV, VMUL_VV, VMULH_VV, and VMULHU_VV.
----
This PR is stacked on top of #88306 (specifically, commit 8e67ee297d4e050b6fc0ac7cc6d2c14d514c8997).
I also put the pre-commit test in a separate commit so that it's easier to see the differences. I will either squash it or push it separately before merging.
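As an aside for reviewers (illustrative only, not part of the patch): the chains in the new i8 tests can be regrouped because wrapping integer add/mul is associative and commutative, so v0 + (v0 + (v0 + v1)) can become (v0 + v0) + (v0 + v1), cutting the dependent chain from three operations to two. A minimal standalone C++ sketch of that equivalence:

  #include <cassert>
  #include <cstdint>

  int main() {
    // Exhaustively check the regrouping for 8-bit lanes (SEW=8 in the tests).
    for (unsigned X = 0; X < 256; ++X) {
      for (unsigned Y = 0; Y < 256; ++Y) {
        uint8_t V0 = static_cast<uint8_t>(X), V1 = static_cast<uint8_t>(Y);
        // Serial chain: each add depends on the previous result (depth 3).
        uint8_t Serial = static_cast<uint8_t>(V0 + (V0 + (V0 + V1)));
        // Reassociated form: the two inner adds are independent (depth 2).
        uint8_t Regrouped = static_cast<uint8_t>((V0 + V0) + (V0 + V1));
        assert(Serial == Regrouped);
      }
    }
    return 0;
  }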
From 8e67ee297d4e050b6fc0ac7cc6d2c14d514c8997 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 11:41:27 -0700
Subject: [PATCH 1/3] [CodeGen][TII] Allow reassociation on custom operand
indices
This opens the door to reusing the reassociation optimizations on
target-specific binary operations with non-standard operand lists.
This is effectively NFC.
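As a sketch of the new hook's contract (standalone C++, illustrative only; the table values come from the default implementation in this patch), each REASSOC_* pattern maps to a five-entry array of operand indices {Prev, A, B, X, Y}: element 0 is the index of the Root operand whose definition is Prev, A and X index into Prev, and B and Y index into Root. For a plain "dst, src1, src2" binary operation:

  #include <array>
  #include <cassert>

  // Mirrors the default table added to TargetInstrInfo::getReassociateOperandIdx
  // for a plain "dst, src1, src2" binary operation.
  enum class ReassocPattern { AX_BY, AX_YB, XA_BY, XA_YB };

  static std::array<unsigned, 5> getReassociateOperandIdx(ReassocPattern P) {
    switch (P) {
    case ReassocPattern::AX_BY: return {1, 1, 1, 2, 2};
    case ReassocPattern::AX_YB: return {2, 1, 2, 2, 1};
    case ReassocPattern::XA_BY: return {1, 2, 1, 1, 2};
    case ReassocPattern::XA_YB: return {2, 2, 2, 1, 1};
    }
    return {0, 0, 0, 0, 0};
  }

  int main() {
    // For REASSOC_AX_BY: Prev defines Root's operand 1 (element 0 == 1);
    // A/X index into Prev's sources, B/Y index into Root's sources.
    std::array<unsigned, 5> Idx = getReassociateOperandIdx(ReassocPattern::AX_BY);
    assert(Idx[0] == 1 && Idx[1] == 1 && Idx[2] == 1 && Idx[3] == 2 && Idx[4] == 2);
    return 0;
  }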
---
llvm/include/llvm/CodeGen/TargetInstrInfo.h | 11 ++
llvm/lib/CodeGen/TargetInstrInfo.cpp | 145 ++++++++++++++------
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 8 +-
3 files changed, 115 insertions(+), 49 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 9fd0ebe6956fbe..82c952b227557d 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -30,6 +30,7 @@
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/ErrorHandling.h"
+#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
return true;
}
+ /// The returned array encodes the operand index for each parameter because
+ /// the operands may be commuted; the operand indices for associative
+ /// operations might also be target-specific. Each element specifies the index
+ /// of {Prev, A, B, X, Y}.
+ virtual void
+ getReassociateOperandIdx(const MachineInstr &Root,
+ MachineCombinerPattern Pattern,
+ std::array<unsigned, 5> &OperandIndices) const;
+
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
/// reduce critical path length.
void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
/// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 9fbd516acea8e1..488922e3c1b720 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1051,13 +1051,34 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
}
}
+void TargetInstrInfo::getReassociateOperandIdx(
+ const MachineInstr &Root, MachineCombinerPattern Pattern,
+ std::array<unsigned, 5> &OperandIndices) const {
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ OperandIndices = {1, 1, 1, 2, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ OperandIndices = {2, 1, 2, 2, 1};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ OperandIndices = {1, 2, 1, 1, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ OperandIndices = {2, 2, 2, 1, 1};
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+}
+
/// Attempt the reassociation transformation to reduce critical path length.
/// See the above comments before getMachineCombinerPatterns().
void TargetInstrInfo::reassociateOps(
- MachineInstr &Root, MachineInstr &Prev,
- MachineCombinerPattern Pattern,
+ MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
MachineFunction *MF = Root.getMF();
MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1065,29 +1086,10 @@ void TargetInstrInfo::reassociateOps(
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
- // This array encodes the operand index for each parameter because the
- // operands may be commuted. Each row corresponds to a pattern value,
- // and each column specifies the index of A, B, X, Y.
- unsigned OpIdx[4][4] = {
- { 1, 1, 2, 2 },
- { 1, 2, 2, 1 },
- { 2, 1, 1, 2 },
- { 2, 2, 1, 1 }
- };
-
- int Row;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
- case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
- case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
- case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
- default: llvm_unreachable("unexpected MachineCombinerPattern");
- }
-
- MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
- MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
- MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
- MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+ MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+ MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+ MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+ MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
MachineOperand &OpC = Root.getOperand(0);
Register RegA = OpA.getReg();
@@ -1126,11 +1128,62 @@ void TargetInstrInfo::reassociateOps(
std::swap(KillX, KillY);
}
+ unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+ unsigned RootFirstOpIdx, RootSecondOpIdx;
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+
+ // Basically BuildMI but doesn't add implicit operands by default.
+ auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+ const MCInstrDesc &MCID, Register DestReg) {
+ return MachineInstrBuilder(
+ MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+ .setPCSections(MIMD.getPCSections())
+ .addReg(DestReg, RegState::Define);
+ };
+
// Create new instructions for insertion.
MachineInstrBuilder MIB1 =
- BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
- .addReg(RegX, getKillRegState(KillX))
- .addReg(RegY, getKillRegState(KillY));
+ buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+ for (const auto &MO : Prev.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+    // Skip the result operand, which was already added above.
+ if (Idx == 0)
+ continue;
+ if (Idx == PrevFirstOpIdx)
+ MIB1.addReg(RegX, getKillRegState(KillX));
+ else if (Idx == PrevSecondOpIdx)
+ MIB1.addReg(RegY, getKillRegState(KillY));
+ else
+ MIB1.add(MO);
+ }
+ MIB1.copyImplicitOps(Prev);
if (SwapRootOperands) {
std::swap(RegA, NewVR);
@@ -1138,9 +1191,20 @@ void TargetInstrInfo::reassociateOps(
}
MachineInstrBuilder MIB2 =
- BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
- .addReg(RegA, getKillRegState(KillA))
- .addReg(NewVR, getKillRegState(KillNewVR));
+ buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+ for (const auto &MO : Root.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+ // Skip the result operand.
+ if (Idx == 0)
+ continue;
+ if (Idx == RootFirstOpIdx)
+ MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+ else if (Idx == RootSecondOpIdx)
+ MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+ else
+ MIB2 = MIB2.add(MO);
+ }
+ MIB2.copyImplicitOps(Root);
// Propagate FP flags from the original instructions.
// But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
// Select the previous instruction in the sequence based on the input pattern.
- MachineInstr *Prev = nullptr;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY:
- case MachineCombinerPattern::REASSOC_XA_BY:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
- break;
- case MachineCombinerPattern::REASSOC_AX_YB:
- case MachineCombinerPattern::REASSOC_XA_YB:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
- break;
- default:
- llvm_unreachable("Unknown pattern for machine combiner");
- }
+ std::array<unsigned, 5> OpIdx;
+ getReassociateOperandIdx(Root, Pattern, OpIdx);
+ MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
// Don't reassociate if Prev and Root are in different blocks.
if (Prev->getParent() != Root.getParent())
return;
- reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+ reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+ InstIdxForVirtReg);
}
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6b75efe684d913..5eeb0d7c27cb98 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
MachineFunction &MF = *Root.getMF();
for (auto *NewMI : InsInstrs) {
- assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
- NewMI->getOpcode(), RISCV::OpName::frm)) ==
- NewMI->getNumOperands() &&
- "Instruction has unexpected number of operands");
+    // The FRM operand was already added when this instruction was built; skip it.
+ if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+ NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+ continue;
MachineInstrBuilder MIB(MF, NewMI);
MIB.add(FRM);
if (FRM.getImm() == RISCVFPRndMode::DYN)
From 7c67d617e3bb2b46b528b115439f3e00f7f68790 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 12:23:52 -0700
Subject: [PATCH 2/3] [RISCV][MachineCombiner] Pre-commit test for RVV
reassociations
---
.../RISCV/rvv/vector-reassociations.ll | 254 ++++++++++++++++++
1 file changed, 254 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
new file mode 100644
index 00000000000000..3cb6f3c35286cf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -0,0 +1,254 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr='+v' -O3 %s -o - | FileCheck %s
+
+declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ <vscale x 1 x i1>,
+ i32, i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vsub.nxv1i8.nxv1i8(
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ <vscale x 1 x i8>,
+ i32)
+
+define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vadd_vv:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT: vadd.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ i32 %2)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ i32 %2)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ i32 %2)
+
+ ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vadd_vsub_vv:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT: vsub.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vsub.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ i32 %2)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ i32 %2)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ i32 %2)
+
+ ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vmul_vv:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT: vmul.vv v9, v8, v9
+; CHECK-NEXT: vmul.vv v9, v8, v9
+; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ i32 %2)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ i32 %2)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> undef,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ i32 %2)
+
+ ret <vscale x 1 x i8> %c
+}
+
+; With passthru and masks.
+define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: vadd_vv_passthru:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, tu, ma
+; CHECK-NEXT: vmv1r.v v10, v8
+; CHECK-NEXT: vadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v9, v8
+; CHECK-NEXT: vadd.vv v9, v8, v10
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ i32 %2)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ i32 %2)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ i32 %2)
+
+ ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_passthru_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: vadd_vv_passthru_negative:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, tu, ma
+; CHECK-NEXT: vmv1r.v v10, v8
+; CHECK-NEXT: vadd.vv v10, v8, v9
+; CHECK-NEXT: vadd.vv v9, v8, v10
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ i32 %2)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %1,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ i32 %2)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ i32 %2)
+
+ ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+; CHECK-LABEL: vadd_vv_mask:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu
+; CHECK-NEXT: vmv1r.v v10, v8
+; CHECK-NEXT: vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v9, v8
+; CHECK-NEXT: vadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ <vscale x 1 x i1> %m,
+ i32 %2, i32 1)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ <vscale x 1 x i1> %m,
+ i32 %2, i32 1)
+
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ <vscale x 1 x i1> %m,
+ i32 %2, i32 1)
+
+ ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+; CHECK-LABEL: vadd_vv_mask_negative:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu
+; CHECK-NEXT: vmv1r.v v10, v8
+; CHECK-NEXT: vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v9, v8
+; CHECK-NEXT: vadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+entry:
+ %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %1,
+ <vscale x 1 x i1> %m,
+ i32 %2, i32 1)
+
+ %b = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %a,
+ <vscale x 1 x i1> %m,
+ i32 %2, i32 1)
+
+ %splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
+ %m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+ %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %0,
+ <vscale x 1 x i8> %b,
+ <vscale x 1 x i1> %m2,
+ i32 %2, i32 1)
+
+ ret <vscale x 1 x i8> %c
+}
+
From 1d5d09f9722697bae8bb8067d66b79ae6f56edab Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 10 Apr 2024 11:10:26 -0700
Subject: [PATCH 3/3] [RISCV][MachineCombiner] Add reassociation optimizations
for RVV instructions
This patch covers VADD_VV, VMUL_VV, VMULH_VV, and VMULHU_VV.
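The RVV-specific twist (see the override below) is that these vector pseudos carry a passthru operand ahead of the two vector sources, with VL/SEW/policy operands trailing them, so the RISC-V override of getReassociateOperandIdx only needs to shift every default index up by one. A standalone C++ sketch of that adjustment (illustrative only, not the patch itself):

  #include <array>
  #include <cassert>

  int main() {
    // Default {Prev, A, B, X, Y} indices for REASSOC_AX_BY on "dst, src1, src2".
    std::array<unsigned, 5> OperandIndices = {1, 1, 1, 2, 2};
    // An RVV binary pseudo looks like "dst, passthru, src1, src2, ..." (VL, SEW
    // and friends trail the sources), so skip the passthru by bumping each index.
    for (unsigned &Idx : OperandIndices)
      ++Idx;
    assert(OperandIndices[0] == 2 && OperandIndices[1] == 2 &&
           OperandIndices[3] == 3 && OperandIndices[4] == 3);
    return 0;
  }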
---
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 220 ++++++++++++++++++
llvm/lib/Target/RISCV/RISCVInstrInfo.h | 14 ++
.../RISCV/rvv/vector-reassociations.ll | 14 +-
3 files changed, 241 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5eeb0d7c27cb98..d427842317881c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1619,8 +1619,184 @@ static bool isFMUL(unsigned Opc) {
}
}
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+ bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC) \
+ case RISCV::OPC##_M1: \
+ case RISCV::OPC##_M2: \
+ case RISCV::OPC##_M4: \
+ case RISCV::OPC##_M8: \
+ case RISCV::OPC##_MF2: \
+ case RISCV::OPC##_MF4: \
+ case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC) \
+ case RISCV::OPC##_M1_MASK: \
+ case RISCV::OPC##_M2_MASK: \
+ case RISCV::OPC##_M4_MASK: \
+ case RISCV::OPC##_M8_MASK: \
+ case RISCV::OPC##_MF2_MASK: \
+ case RISCV::OPC##_MF4_MASK: \
+ case RISCV::OPC##_MF8_MASK
+
+ unsigned Opcode = Inst.getOpcode();
+ if (Invert) {
+ if (auto InvOpcode = getInverseOpcode(Opcode))
+ Opcode = *InvOpcode;
+ else
+ return false;
+ }
+
+ // clang-format off
+ switch (Opcode) {
+ default:
+ return false;
+ OPCODE_LMUL_CASE(PseudoVADD_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+ OPCODE_LMUL_CASE(PseudoVMUL_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+ OPCODE_LMUL_CASE(PseudoVMULH_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMULH_VV):
+ OPCODE_LMUL_CASE(PseudoVMULHU_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMULHU_VV):
+ return true;
+ }
+ // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+ const MachineInstr &MI2) const {
+ if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+ return false;
+
+ // Make sure vtype operands are also the same.
+ const MCInstrDesc &Desc = get(MI1.getOpcode());
+ const uint64_t TSFlags = Desc.TSFlags;
+
+ auto checkImmOperand = [&](unsigned OpIdx) {
+ return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
+ };
+
+ auto checkRegOperand = [&](unsigned OpIdx) {
+ return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
+ };
+
+ // PassThru
+ if (!checkRegOperand(1))
+ return false;
+
+ // SEW
+ if (RISCVII::hasSEWOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+ return false;
+
+ // Mask
+ // There might be more sophisticated ways to check equality of masks, but
+ // right now we simply check if they're the same virtual register.
+ if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
+ return false;
+
+ // Tail / Mask policies
+ if (RISCVII::hasVecPolicyOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
+ return false;
+
+ // VL
+ if (RISCVII::hasVLOp(TSFlags)) {
+ unsigned OpIdx = RISCVII::getVLOpNum(Desc);
+ const MachineOperand &Op1 = MI1.getOperand(OpIdx);
+ const MachineOperand &Op2 = MI2.getOperand(OpIdx);
+ if (Op1.getType() != Op2.getType())
+ return false;
+ switch (Op1.getType()) {
+ case MachineOperand::MO_Register:
+ if (Op1.getReg() != Op2.getReg())
+ return false;
+ break;
+ case MachineOperand::MO_Immediate:
+ if (Op1.getImm() != Op2.getImm())
+ return false;
+ break;
+ default:
+ llvm_unreachable("Unrecognized VL operand type");
+ }
+ }
+
+ // Rounding modes
+ if (RISCVII::hasRoundModeOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
+ return false;
+
+ return true;
+}
+
+// Most of our RVV pseudos have a passthru operand, so their real source
+// operands start at index 2.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+ bool &Commuted) const {
+ const MachineBasicBlock *MBB = Inst.getParent();
+ const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+ MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+ MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+ // If only one operand has the same or inverse opcode and it's the second
+ // source operand, the operands must be commuted.
+ Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
+ areRVVInstsReassociable(Inst, *MI2);
+ if (Commuted)
+ std::swap(MI1, MI2);
+
+ return areRVVInstsReassociable(Inst, *MI1) &&
+ (isVectorAssociativeAndCommutative(*MI1) ||
+ isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
+ hasReassociableOperands(*MI1, MBB) &&
+ MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+ const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+ if (!isVectorAssociativeAndCommutative(Inst) &&
+ !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+ return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+
+ const MachineOperand &Op1 = Inst.getOperand(2);
+ const MachineOperand &Op2 = Inst.getOperand(3);
+ const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+ // We need virtual register definitions for the operands that we will
+ // reassociate.
+ MachineInstr *MI1 = nullptr;
+ MachineInstr *MI2 = nullptr;
+ if (Op1.isReg() && Op1.getReg().isVirtual())
+ MI1 = MRI.getUniqueVRegDef(Op1.getReg());
+ if (Op2.isReg() && Op2.getReg().isVirtual())
+ MI2 = MRI.getUniqueVRegDef(Op2.getReg());
+
+ // And at least one operand must be defined in MBB.
+ return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
+}
+
+void RISCVInstrInfo::getReassociateOperandIdx(
+ const MachineInstr &Root, MachineCombinerPattern Pattern,
+ std::array<unsigned, 5> &OperandIndices) const {
+ TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
+ if (isVectorAssociativeAndCommutative(Root) ||
+ isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
+    // Skip the passthru operand by shifting every index up by one.
+ for (unsigned I = 0; I < 5; ++I)
+ ++OperandIndices[I];
+ }
+}
+
bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const {
+ if (isVectorAssociativeAndCommutative(Inst) ||
+ isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+ return hasReassociableVectorSibling(Inst, Commuted);
+
if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
return false;
@@ -1640,6 +1816,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
bool Invert) const {
+ if (isVectorAssociativeAndCommutative(Inst, Invert))
+ return true;
+
unsigned Opc = Inst.getOpcode();
if (Invert) {
auto InverseOpcode = getInverseOpcode(Opc);
@@ -1692,6 +1871,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
std::optional<unsigned>
RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+#define RVV_OPC_LMUL_CASE(OPC, INV) \
+ case RISCV::OPC##_M1: \
+ return RISCV::INV##_M1; \
+ case RISCV::OPC##_M2: \
+ return RISCV::INV##_M2; \
+ case RISCV::OPC##_M4: \
+ return RISCV::INV##_M4; \
+ case RISCV::OPC##_M8: \
+ return RISCV::INV##_M8; \
+ case RISCV::OPC##_MF2: \
+ return RISCV::INV##_MF2; \
+ case RISCV::OPC##_MF4: \
+ return RISCV::INV##_MF4; \
+ case RISCV::OPC##_MF8: \
+ return RISCV::INV##_MF8
+
+#define RVV_OPC_LMUL_MASK_CASE(OPC, INV) \
+ case RISCV::OPC##_M1_MASK: \
+ return RISCV::INV##_M1_MASK; \
+ case RISCV::OPC##_M2_MASK: \
+ return RISCV::INV##_M2_MASK; \
+ case RISCV::OPC##_M4_MASK: \
+ return RISCV::INV##_M4_MASK; \
+ case RISCV::OPC##_M8_MASK: \
+ return RISCV::INV##_M8_MASK; \
+ case RISCV::OPC##_MF2_MASK: \
+ return RISCV::INV##_MF2_MASK; \
+ case RISCV::OPC##_MF4_MASK: \
+ return RISCV::INV##_MF4_MASK; \
+ case RISCV::OPC##_MF8_MASK: \
+ return RISCV::INV##_MF8_MASK
+
switch (Opcode) {
default:
return std::nullopt;
@@ -1715,7 +1926,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
return RISCV::SUBW;
case RISCV::SUBW:
return RISCV::ADDW;
+ // clang-format off
+ RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+ RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+ RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+ RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+ // clang-format on
}
+
+#undef RVV_OPC_LMUL_MASK_CASE
+#undef RVV_OPC_LMUL_CASE
}
static bool canCombineFPFusedMultiply(const MachineInstr &Root,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 81d9c9db783c02..ecb5628f13d8ac 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -255,6 +255,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
+ bool hasReassociableOperands(const MachineInstr &Inst,
+ const MachineBasicBlock *MBB) const override;
+
bool hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const override;
@@ -263,6 +266,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
+ void getReassociateOperandIdx(
+ const MachineInstr &Root, MachineCombinerPattern Pattern,
+ std::array<unsigned, 5> &OperandIndices) const override;
+
ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
getSerializableMachineMemOperandTargetFlags() const override;
@@ -286,6 +293,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
private:
unsigned getInstBundleLength(const MachineInstr &MI) const;
+
+ bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
+ bool Invert = false) const;
+ bool areRVVInstsReassociable(const MachineInstr &MI1,
+ const MachineInstr &MI2) const;
+ bool hasReassociableVectorSibling(const MachineInstr &Inst,
+ bool &Commuted) const;
};
namespace RISCV {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
index 3cb6f3c35286cf..7c3d48c3e48a73 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -31,7 +31,7 @@ define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
; CHECK-NEXT: vadd.vv v9, v8, v9
-; CHECK-NEXT: vadd.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: vadd.vv v8, v8, v9
; CHECK-NEXT: ret
entry:
@@ -61,7 +61,7 @@ define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
; CHECK-NEXT: vsub.vv v9, v8, v9
-; CHECK-NEXT: vadd.vv v9, v8, v9
+; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: vadd.vv v8, v8, v9
; CHECK-NEXT: ret
entry:
@@ -91,7 +91,7 @@ define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
; CHECK-NEXT: vmul.vv v9, v8, v9
-; CHECK-NEXT: vmul.vv v9, v8, v9
+; CHECK-NEXT: vmul.vv v8, v8, v8
; CHECK-NEXT: vmul.vv v8, v8, v9
; CHECK-NEXT: ret
entry:
@@ -124,8 +124,8 @@ define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i
; CHECK-NEXT: vmv1r.v v10, v8
; CHECK-NEXT: vadd.vv v10, v8, v9
; CHECK-NEXT: vmv1r.v v9, v8
-; CHECK-NEXT: vadd.vv v9, v8, v10
-; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: vadd.vv v9, v8, v8
+; CHECK-NEXT: vadd.vv v8, v9, v10
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
@@ -187,8 +187,8 @@ define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %
; CHECK-NEXT: vmv1r.v v10, v8
; CHECK-NEXT: vadd.vv v10, v8, v9, v0.t
; CHECK-NEXT: vmv1r.v v9, v8
-; CHECK-NEXT: vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT: vadd.vv v9, v8, v8, v0.t
+; CHECK-NEXT: vadd.vv v8, v9, v10, v0.t
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(