[llvm] [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (PR #88307)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 15 23:25:49 PDT 2024


================
@@ -1619,8 +1619,223 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode;
+    else
+      return false;
+  }
+
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false;
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+                                             const MachineInstr &MI2) const {
+  if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+    return false;
+
+  assert(MI1.getMF() == MI2.getMF());
+  const MachineRegisterInfo *MRI = &MI1.getMF()->getRegInfo();
+  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
+
+  // Make sure vtype operands are also the same.
+  const MCInstrDesc &Desc = get(MI1.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+
+  auto checkImmOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
+  };
+
+  auto checkRegOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
+  };
+
+  // PassThru
+  if (!checkRegOperand(1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask
+  if (RISCVII::usesMaskPolicy(TSFlags)) {
+    const MachineBasicBlock *MBB = MI1.getParent();
+    const MachineBasicBlock::const_reverse_iterator It1(&MI1);
+    const MachineBasicBlock::const_reverse_iterator It2(&MI2);
+    Register MI1VReg;
+
+    bool SeenMI2 = false;
+    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
+      if (It == It2) {
+        SeenMI2 = true;
+        if (!MI1VReg.isValid())
+          // There is no V0 def between MI1 and MI2; they're sharing the
+          // same V0.
+          break;
+      }
+
+      if (It->definesRegister(RISCV::V0, TRI)) {
+        Register SrcReg =
+            TRI->lookThruCopyLike(It->getOperand(1).getReg(), MRI);
+
+        if (!MI1VReg.isValid()) {
+          // This is the V0 def for MI1.
+          MI1VReg = SrcReg;
+          continue;
+        }
+
+        // Some random mask updates.
+        if (!SeenMI2)
+          continue;
+
+        // This is the V0 def for MI2; check if it's the same as that of
+        // MI1.
+        if (MI1VReg != SrcReg)
----------------
mshockwave wrote:

> we would need to check that nothing writes to it in between M1 and M2.

Absolutely, and I think I've done that: the loop starting at line 1704 traverses backward from M1, trying to find the last V0 def before M1. It then keeps going, and once it has traversed past M2 it looks for the last V0 def before M2 and compares it with the one we already retrieved for M1.
For cases where there is no V0 def between M1 and M2, the check on line 1707 covers that.

https://github.com/llvm/llvm-project/pull/88307


More information about the llvm-commits mailing list