[llvm] [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (PR #88307)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 11:36:27 PDT 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/88307

>From 8e67ee297d4e050b6fc0ac7cc6d2c14d514c8997 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 11:41:27 -0700
Subject: [PATCH 1/5] [CodeGen][TII] Allow reassociation on custom operand
 indices

This opens the door to reusing reassociation optimizations on target-specific
binary operations with non-standard operand lists.

This is effectively NFC.
---
 llvm/include/llvm/CodeGen/TargetInstrInfo.h |  11 ++
 llvm/lib/CodeGen/TargetInstrInfo.cpp        | 145 ++++++++++++++------
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp    |   8 +-
 3 files changed, 115 insertions(+), 49 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 9fd0ebe6956fbe..82c952b227557d 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -30,6 +30,7 @@
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <array>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
     return true;
   }
 
+  /// The returned array encodes the operand index for each parameter because
+  /// the operands may be commuted; the operand indices for associative
+  /// operations might also be target-specific. Each element specifies the index
+  /// of {Prev, A, B, X, Y}.
+  virtual void
+  getReassociateOperandIdx(const MachineInstr &Root,
+                           MachineCombinerPattern Pattern,
+                           std::array<unsigned, 5> &OperandIndices) const;
+
   /// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
   /// reduce critical path length.
   void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
                       MachineCombinerPattern Pattern,
                       SmallVectorImpl<MachineInstr *> &InsInstrs,
                       SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      ArrayRef<unsigned> OperandIndices,
                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
 
   /// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 9fbd516acea8e1..488922e3c1b720 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1051,13 +1051,34 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
   }
 }
 
+void TargetInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  using MCP = MachineCombinerPattern;
+  // Default layout: operand 0 is the result and operands 1 and 2 are the two
+  // sources. A and X are sources of Prev, B and Y are sources of Root.
+  if (Pattern != MCP::REASSOC_AX_BY && Pattern != MCP::REASSOC_AX_YB &&
+      Pattern != MCP::REASSOC_XA_BY && Pattern != MCP::REASSOC_XA_YB)
+    llvm_unreachable("unexpected MachineCombinerPattern");
+
+  // "AX_*" puts A in source slot 1 of Prev, "XA_*" puts it in slot 2.
+  const unsigned IdxA =
+      (Pattern == MCP::REASSOC_AX_BY || Pattern == MCP::REASSOC_AX_YB) ? 1 : 2;
+  // "*_BY" puts B in source slot 1 of Root, "*_YB" puts it in slot 2.
+  const unsigned IdxB =
+      (Pattern == MCP::REASSOC_AX_BY || Pattern == MCP::REASSOC_XA_BY) ? 1 : 2;
+  // B is the Root operand defined by Prev, so Prev is found at IdxB; X and Y
+  // occupy whichever source slot A and B do not. Order: {Prev, A, B, X, Y}.
+  OperandIndices = {IdxB, IdxA, IdxB, 3 - IdxA, 3 - IdxB};
+}
+
 /// Attempt the reassociation transformation to reduce critical path length.
 /// See the above comments before getMachineCombinerPatterns().
 void TargetInstrInfo::reassociateOps(
-    MachineInstr &Root, MachineInstr &Prev,
-    MachineCombinerPattern Pattern,
+    MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
     SmallVectorImpl<MachineInstr *> &DelInstrs,
+    ArrayRef<unsigned> OperandIndices,
     DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
   MachineFunction *MF = Root.getMF();
   MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1065,29 +1086,10 @@ void TargetInstrInfo::reassociateOps(
   const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
   const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
 
-  // This array encodes the operand index for each parameter because the
-  // operands may be commuted. Each row corresponds to a pattern value,
-  // and each column specifies the index of A, B, X, Y.
-  unsigned OpIdx[4][4] = {
-    { 1, 1, 2, 2 },
-    { 1, 2, 2, 1 },
-    { 2, 1, 1, 2 },
-    { 2, 2, 1, 1 }
-  };
-
-  int Row;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
-  case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
-  case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
-  case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
-  default: llvm_unreachable("unexpected MachineCombinerPattern");
-  }
-
-  MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
-  MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
-  MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
-  MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+  MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+  MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+  MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+  MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
   MachineOperand &OpC = Root.getOperand(0);
 
   Register RegA = OpA.getReg();
@@ -1126,11 +1128,62 @@ void TargetInstrInfo::reassociateOps(
     std::swap(KillX, KillY);
   }
 
+  unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+  unsigned RootFirstOpIdx, RootSecondOpIdx;
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+
+  // Basically BuildMI but doesn't add implicit operands by default.
+  auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+                              const MCInstrDesc &MCID, Register DestReg) {
+    return MachineInstrBuilder(
+               MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+        .setPCSections(MIMD.getPCSections())
+        .addReg(DestReg, RegState::Define);
+  };
+
   // Create new instructions for insertion.
   MachineInstrBuilder MIB1 =
-      BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
-          .addReg(RegX, getKillRegState(KillX))
-          .addReg(RegY, getKillRegState(KillY));
+      buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+  for (const auto &MO : Prev.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand we'd already added.
+    if (Idx == 0)
+      continue;
+    if (Idx == PrevFirstOpIdx)
+      MIB1.addReg(RegX, getKillRegState(KillX));
+    else if (Idx == PrevSecondOpIdx)
+      MIB1.addReg(RegY, getKillRegState(KillY));
+    else
+      MIB1.add(MO);
+  }
+  MIB1.copyImplicitOps(Prev);
 
   if (SwapRootOperands) {
     std::swap(RegA, NewVR);
@@ -1138,9 +1191,20 @@ void TargetInstrInfo::reassociateOps(
   }
 
   MachineInstrBuilder MIB2 =
-      BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
-          .addReg(RegA, getKillRegState(KillA))
-          .addReg(NewVR, getKillRegState(KillNewVR));
+      buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+  for (const auto &MO : Root.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand.
+    if (Idx == 0)
+      continue;
+    if (Idx == RootFirstOpIdx)
+      MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+    else if (Idx == RootSecondOpIdx)
+      MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+    else
+      MIB2 = MIB2.add(MO);
+  }
+  MIB2.copyImplicitOps(Root);
 
   // Propagate FP flags from the original instructions.
   // But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
   MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
 
   // Select the previous instruction in the sequence based on the input pattern.
-  MachineInstr *Prev = nullptr;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY:
-  case MachineCombinerPattern::REASSOC_XA_BY:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
-    break;
-  case MachineCombinerPattern::REASSOC_AX_YB:
-  case MachineCombinerPattern::REASSOC_XA_YB:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
-    break;
-  default:
-    llvm_unreachable("Unknown pattern for machine combiner");
-  }
+  std::array<unsigned, 5> OpIdx;
+  getReassociateOperandIdx(Root, Pattern, OpIdx);
+  MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
 
   // Don't reassociate if Prev and Root are in different blocks.
   if (Prev->getParent() != Root.getParent())
     return;
 
-  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+                 InstIdxForVirtReg);
 }
 
 MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6b75efe684d913..5eeb0d7c27cb98 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
   MachineFunction &MF = *Root.getMF();
 
   for (auto *NewMI : InsInstrs) {
-    assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
-               NewMI->getOpcode(), RISCV::OpName::frm)) ==
-               NewMI->getNumOperands() &&
-           "Instruction has unexpected number of operands");
+    // We'd already added the FRM operand.
+    if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+            NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+      continue;
     MachineInstrBuilder MIB(MF, NewMI);
     MIB.add(FRM);
     if (FRM.getImm() == RISCVFPRndMode::DYN)

>From 7c67d617e3bb2b46b528b115439f3e00f7f68790 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 8 Apr 2024 12:23:52 -0700
Subject: [PATCH 2/5] [RISCV][MachineCombiner] Pre-commit test for RVV
 reassociations

---
 .../RISCV/rvv/vector-reassociations.ll        | 254 ++++++++++++++++++
 1 file changed, 254 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll

diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
new file mode 100644
index 00000000000000..3cb6f3c35286cf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -0,0 +1,254 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr='+v' -O3 %s -o - | FileCheck %s
+
+declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  <vscale x 1 x i1>,
+  i32, i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vsub.nxv1i8.nxv1i8(
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  i32)
+
+declare <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  i32)
+
+define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vadd_vv:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    i32 %2)
+  ; %b and %c each add %0 to the previous result: a serial dependency chain.
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    i32 %2)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    i32 %2)
+
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vadd_vsub_vv:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT:    vsub.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vsub.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    i32 %2)
+  ; %a is a vsub (the inverse of vadd) feeding the vadd chain %b -> %c.
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    i32 %2)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    i32 %2)
+
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: simple_vmul_vv:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT:    vmul.vv v9, v8, v9
+; CHECK-NEXT:    vmul.vv v9, v8, v9
+; CHECK-NEXT:    vmul.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    i32 %2)
+  ; %b and %c each multiply %0 by the previous result: a serial chain.
+  %b = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    i32 %2)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vmul.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    i32 %2)
+
+  ret <vscale x 1 x i8> %c
+}
+
+; The tests below use a non-undef passthru and/or a mask.
+define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: vadd_vv_passthru:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, tu, ma
+; CHECK-NEXT:    vmv1r.v v10, v8
+; CHECK-NEXT:    vadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vadd.vv v9, v8, v10
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    i32 %2)
+  ; All three vadds use the same passthru (%0).
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    i32 %2)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    i32 %2)
+
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_passthru_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2) nounwind {
+; CHECK-LABEL: vadd_vv_passthru_negative:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, tu, ma
+; CHECK-NEXT:    vmv1r.v v10, v8
+; CHECK-NEXT:    vadd.vv v10, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v10
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    i32 %2)
+  ; Negative test: %b uses %1 as its passthru while %a and %c use %0.
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %1,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    i32 %2)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    i32 %2)
+
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+; CHECK-LABEL: vadd_vv_mask:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, mu
+; CHECK-NEXT:    vmv1r.v v10, v8
+; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    <vscale x 1 x i1> %m,
+    i32 %2, i32 1)
+  ; All three masked vadds share the passthru %0 and the mask %m.
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    <vscale x 1 x i1> %m,
+    i32 %2, i32 1)
+
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    <vscale x 1 x i1> %m,
+    i32 %2, i32 1)
+
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+; CHECK-LABEL: vadd_vv_mask_negative:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, mu
+; CHECK-NEXT:    vmv1r.v v10, v8
+; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %1,
+    <vscale x 1 x i1> %m,
+    i32 %2, i32 1)
+
+  %b = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %a,
+    <vscale x 1 x i1> %m,
+    i32 %2, i32 1)
+  ; Negative test: %c below uses an all-ones mask (%m2) instead of %m.
+  %splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
+  %m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %0,
+    <vscale x 1 x i8> %b,
+    <vscale x 1 x i1> %m2,
+    i32 %2, i32 1)
+
+  ret <vscale x 1 x i8> %c
+}
+

>From 1d5d09f9722697bae8bb8067d66b79ae6f56edab Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 10 Apr 2024 11:10:26 -0700
Subject: [PATCH 3/5] [RISCV][MachineCombiner] Add reassociation optimizations
 for RVV instructions

This patch covers VADD_VV, VMUL_VV, VMULU_VV, and VMULUH_VV.
---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 220 ++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.h        |  14 ++
 .../RISCV/rvv/vector-reassociations.ll        |  14 +-
 3 files changed, 241 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5eeb0d7c27cb98..d427842317881c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1619,8 +1619,184 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode;
+    else
+      return false;
+  }
+
+  // Only opcodes that are associative as well as commutative may be listed.
+  // vmulh/vmulhu are commutative but NOT associative since they keep only the
+  // high half of the product: with e8 vmulhu, mulhu(mulhu(255, 255), 2) = 1
+  // while mulhu(255, mulhu(255, 2)) = 0.
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false;
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+                                             const MachineInstr &MI2) const {
+  // MI1 is the root and MI2 is the instruction defining one of its source
+  // operands. They can only be reassociated if their opcodes are equal or
+  // inverse and they agree on the passthru, mask and vtype operands.
+  if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+    return false;
+
+  const MCInstrDesc &Desc = get(MI1.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+  auto isSameOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).isIdenticalTo(MI2.getOperand(OpIdx));
+  };
+
+  if (!isSameOperand(/*PassThru=*/1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) && !isSameOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask: the operand is always the physical register V0, so comparing it is
+  // meaningless. Walking backwards from MI1, require that V0 is not redefined
+  // before MI2, or that the V0 defs reaching both copy the same virtual reg.
+  auto haveSameMask = [&]() {
+    Register MI1Mask;
+    bool SeenMI2 = false;
+    for (auto It = MachineBasicBlock::const_reverse_iterator(&MI1);
+         It != MI1.getParent()->rend(); ++It) {
+      if (&*It == &MI2) {
+        if (!MI1Mask.isValid())
+          return true; // No V0 def between MI2 and MI1.
+        SeenMI2 = true;
+      } else if (It->modifiesRegister(RISCV::V0, STI.getRegisterInfo())) {
+        // Only COPYs from virtual registers can be tracked reliably.
+        if (!It->isCopy() || !It->getOperand(1).getReg().isVirtual())
+          return false;
+        Register Src = It->getOperand(1).getReg();
+        if (SeenMI2)
+          return Src == MI1Mask; // The V0 def reaching MI2.
+        if (!MI1Mask.isValid())
+          MI1Mask = Src; // The V0 def reaching MI1.
+      }
+    }
+    return false; // MI2 is not in MI1's block, or its mask is unknown.
+  };
+  if (RISCVII::usesMaskPolicy(TSFlags) && !haveSameMask())
+    return false;
+
+  // Tail / Mask policies
+  if (RISCVII::hasVecPolicyOp(TSFlags) &&
+      !isSameOperand(RISCVII::getVecPolicyOpNum(Desc)))
+    return false;
+
+  // VL (a register or an immediate)
+  if (RISCVII::hasVLOp(TSFlags) && !isSameOperand(RISCVII::getVLOpNum(Desc)))
+    return false;
+
+  // Rounding modes
+  if (RISCVII::hasRoundModeOp(TSFlags) &&
+      !isSameOperand(RISCVII::getVLOpNum(Desc) - 1))
+    return false;
+
+  return true;
+}
+
+// The RVV pseudos handled here carry a passthru operand at index 1, so their
+// two real source operands are at indices 2 and 3.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+                                                  bool &Commuted) const {
+  const MachineBasicBlock *MBB = Inst.getParent();
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+  MachineInstr *Sibling = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+  MachineInstr *Other = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+  // Prefer the def of the first source; fall back to the second one (and
+  // report the operands as commuted) only when the first does not match.
+  const bool FirstMatches = areRVVInstsReassociable(Inst, *Sibling);
+  Commuted = !FirstMatches && areRVVInstsReassociable(Inst, *Other);
+  if (Commuted)
+    Sibling = Other;
+  else if (!FirstMatches)
+    return false;
+  return (isVectorAssociativeAndCommutative(*Sibling) ||
+          isVectorAssociativeAndCommutative(*Sibling, /*Invert=*/true)) &&
+         hasReassociableOperands(*Sibling, MBB) &&
+         MRI.hasOneNonDBGUse(Sibling->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+    const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+  // Anything that is not a reassociable RVV pseudo uses the generic check.
+  const bool IsRVV = isVectorAssociativeAndCommutative(Inst) ||
+                     isVectorAssociativeAndCommutative(Inst, /*Invert=*/true);
+  if (!IsRVV)
+    return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+  // RVV pseudos have a passthru at index 1, so the sources are operands 2 and
+  // 3. Each must be a virtual register with a unique definition.
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+  auto getVirtualDef = [&MRI](const MachineOperand &MO) -> MachineInstr * {
+    if (!MO.isReg() || !MO.getReg().isVirtual())
+      return nullptr;
+    return MRI.getUniqueVRegDef(MO.getReg());
+  };
+  const MachineInstr *Def1 = getVirtualDef(Inst.getOperand(2));
+  const MachineInstr *Def2 = getVirtualDef(Inst.getOperand(3));
+  if (!Def1 || !Def2)
+    return false;
+  // At least one of the two definitions must be located in MBB.
+  return Def1->getParent() == MBB || Def2->getParent() == MBB;
+}
+
+void RISCVInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
+  // RVV pseudos have a passthru operand at index 1: shift every index by one.
+  if (!isVectorAssociativeAndCommutative(Root) &&
+      !isVectorAssociativeAndCommutative(Root, /*Invert=*/true))
+    return;
+  for (unsigned &Idx : OperandIndices)
+    ++Idx;
+}
+
 bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
                                             bool &Commuted) const {
+  if (isVectorAssociativeAndCommutative(Inst) ||
+      isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return hasReassociableVectorSibling(Inst, Commuted);
+
   if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
     return false;
 
@@ -1640,6 +1816,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
 
 bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                  bool Invert) const {
+  if (isVectorAssociativeAndCommutative(Inst, Invert))
+    return true;
+
   unsigned Opc = Inst.getOpcode();
   if (Invert) {
     auto InverseOpcode = getInverseOpcode(Opc);
@@ -1692,6 +1871,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
 
 std::optional<unsigned>
 RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+#define RVV_OPC_LMUL_CASE(OPC, INV)                                            \
+  case RISCV::OPC##_M1:                                                        \
+    return RISCV::INV##_M1;                                                    \
+  case RISCV::OPC##_M2:                                                        \
+    return RISCV::INV##_M2;                                                    \
+  case RISCV::OPC##_M4:                                                        \
+    return RISCV::INV##_M4;                                                    \
+  case RISCV::OPC##_M8:                                                        \
+    return RISCV::INV##_M8;                                                    \
+  case RISCV::OPC##_MF2:                                                       \
+    return RISCV::INV##_MF2;                                                   \
+  case RISCV::OPC##_MF4:                                                       \
+    return RISCV::INV##_MF4;                                                   \
+  case RISCV::OPC##_MF8:                                                       \
+    return RISCV::INV##_MF8
+
+#define RVV_OPC_LMUL_MASK_CASE(OPC, INV)                                       \
+  case RISCV::OPC##_M1_MASK:                                                   \
+    return RISCV::INV##_M1_MASK;                                               \
+  case RISCV::OPC##_M2_MASK:                                                   \
+    return RISCV::INV##_M2_MASK;                                               \
+  case RISCV::OPC##_M4_MASK:                                                   \
+    return RISCV::INV##_M4_MASK;                                               \
+  case RISCV::OPC##_M8_MASK:                                                   \
+    return RISCV::INV##_M8_MASK;                                               \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+    return RISCV::INV##_MF2_MASK;                                              \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+    return RISCV::INV##_MF4_MASK;                                              \
+  case RISCV::OPC##_MF8_MASK:                                                  \
+    return RISCV::INV##_MF8_MASK
+
   switch (Opcode) {
   default:
     return std::nullopt;
@@ -1715,7 +1926,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
     return RISCV::SUBW;
   case RISCV::SUBW:
     return RISCV::ADDW;
+    // clang-format off
+  RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+    // clang-format on
   }
+
+#undef RVV_OPC_LMUL_MASK_CASE
+#undef RVV_OPC_LMUL_CASE
 }
 
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 81d9c9db783c02..ecb5628f13d8ac 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -255,6 +255,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
       SmallVectorImpl<MachineInstr *> &DelInstrs,
       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
 
+  bool hasReassociableOperands(const MachineInstr &Inst,
+                               const MachineBasicBlock *MBB) const override;
+
   bool hasReassociableSibling(const MachineInstr &Inst,
                               bool &Commuted) const override;
 
@@ -263,6 +266,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
+  void getReassociateOperandIdx(
+      const MachineInstr &Root, MachineCombinerPattern Pattern,
+      std::array<unsigned, 5> &OperandIndices) const override;
+
   ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
   getSerializableMachineMemOperandTargetFlags() const override;
 
@@ -286,6 +293,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
 private:
   unsigned getInstBundleLength(const MachineInstr &MI) const;
+
+  bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
+                                         bool Invert = false) const;
+  bool areRVVInstsReassociable(const MachineInstr &MI1,
+                               const MachineInstr &MI2) const;
+  bool hasReassociableVectorSibling(const MachineInstr &Inst,
+                                    bool &Commuted) const;
 };
 
 namespace RISCV {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
index 3cb6f3c35286cf..7c3d48c3e48a73 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -31,7 +31,7 @@ define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vadd.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -61,7 +61,7 @@ define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vsub.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -91,7 +91,7 @@ define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmul.vv v9, v8, v9
-; CHECK-NEXT:    vmul.vv v9, v8, v9
+; CHECK-NEXT:    vmul.vv v8, v8, v8
 ; CHECK-NEXT:    vmul.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -124,8 +124,8 @@ define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v8
+; CHECK-NEXT:    vadd.vv v8, v9, v10
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
@@ -187,8 +187,8 @@ define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vadd.vv v9, v8, v8, v0.t
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(

>From ae6ecdd626381316a74d182104ad47375247c4c7 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 10 Apr 2024 13:43:51 -0700
Subject: [PATCH 4/5] fixup! [RISCV][MachineCombiner] Add reassociation
 optimizations for RVV instructions

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index d427842317881c..e9b896ce5b23a6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1655,10 +1655,6 @@ bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
   OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
   OPCODE_LMUL_CASE(PseudoVMUL_VV):
   OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
-  OPCODE_LMUL_CASE(PseudoVMULH_VV):
-  OPCODE_LMUL_MASK_CASE(PseudoVMULH_VV):
-  OPCODE_LMUL_CASE(PseudoVMULHU_VV):
-  OPCODE_LMUL_MASK_CASE(PseudoVMULHU_VV):
     return true;
   }
   // clang-format on
@@ -1733,7 +1729,7 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
   return true;
 }
 
-// Most of our RVV pseudo has passthru operand, so the real operands
+// Most of our RVV pseudos have a passthru operand, so the real operands
 // start from index = 2.
 bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
                                                   bool &Commuted) const {
@@ -1785,7 +1781,7 @@ void RISCVInstrInfo::getReassociateOperandIdx(
   TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
   if (isVectorAssociativeAndCommutative(Root) ||
       isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
-    // Skip the passthrough operand, so add all indices by one.
+    // Skip the passthrough operand, so increment all indices by one.
     for (unsigned I = 0; I < 5; ++I)
       ++OperandIndices[I];
   }

>From 6279eaf078548a95ede1ed61eae5d090e693abaf Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 11 Apr 2024 11:35:41 -0700
Subject: [PATCH 5/5] Check the definition of mask operand (i.e. V0)

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 51 +++++++++++++++++--
 .../RISCV/rvv/vector-reassociations.ll        | 13 +++--
 2 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index e9b896ce5b23a6..bf6ea9d1f56ef7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1668,6 +1668,10 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
   if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
     return false;
 
+  assert(MI1.getMF() == MI2.getMF());
+  const MachineRegisterInfo *MRI = &MI1.getMF()->getRegInfo();
+  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
+
   // Make sure vtype operands are also the same.
   const MCInstrDesc &Desc = get(MI1.getOpcode());
   const uint64_t TSFlags = Desc.TSFlags;
@@ -1690,10 +1694,49 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
     return false;
 
   // Mask
-  // There might be more sophisticated ways to check equality of masks, but
-  // right now we simply check if they're the same virtual register.
-  if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
-    return false;
+  if (RISCVII::usesMaskPolicy(TSFlags)) {
+    const MachineBasicBlock *MBB = MI1.getParent();
+    const MachineBasicBlock::const_reverse_iterator It1(&MI1);
+    const MachineBasicBlock::const_reverse_iterator It2(&MI2);
+    Register MI1VReg;
+
+    bool SeenMI2 = false;
+    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
+      if (It == It2) {
+        SeenMI2 = true;
+        if (!MI1VReg.isValid())
+          // There is no V0 def between MI1 and MI2; they're sharing the
+          // same V0.
+          break;
+      }
+
+      if (It->definesRegister(RISCV::V0, TRI)) {
+        Register SrcReg =
+            TRI->lookThruCopyLike(It->getOperand(1).getReg(), MRI);
+
+        if (!MI1VReg.isValid()) {
+          // This is the V0 def for MI1.
+          MI1VReg = SrcReg;
+          continue;
+        }
+
+        // An intervening V0 def that feeds neither MI1 nor MI2; skip it.
+        if (!SeenMI2)
+          continue;
+
+        // This is the V0 def for MI2; check if it's the same as that of
+        // MI1.
+        if (MI1VReg != SrcReg)
+          return false;
+        else
+          break;
+      }
+    }
+
+    // If we haven't encountered MI2, it's likely that this function was
+    // called in a wrong way (e.g. MI1 is before MI2).
+    assert(SeenMI2 && "MI2 is expected to appear before MI1");
+  }
 
   // Tail / Mask policies
   if (RISCVII::hasVecPolicyOp(TSFlags) &&
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
index 7c3d48c3e48a73..6435c1c14e061e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -215,15 +215,16 @@ entry:
   ret <vscale x 1 x i8> %c
 }
 
-define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m, <vscale x 1 x i1> %m2) nounwind {
 ; CHECK-LABEL: vadd_vv_mask_negative:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, mu
-; CHECK-NEXT:    vmv1r.v v10, v8
-; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v11, v8
+; CHECK-NEXT:    vadd.vv v11, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v11, v0.t
+; CHECK-NEXT:    vmv1r.v v0, v10
+; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -240,8 +241,6 @@ entry:
     <vscale x 1 x i1> %m,
     i32 %2, i32 1)
 
-  %splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
-  %m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
   %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
     <vscale x 1 x i8> %0,
     <vscale x 1 x i8> %0,



More information about the llvm-commits mailing list