[llvm] [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (PR #88307)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 15 13:43:48 PDT 2024


================
@@ -1619,8 +1619,223 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const { /*
+  * Returns true when the opcode of Inst (or, if Invert is set, the opcode
+  * obtained from getInverseOpcode for it) is one of the PseudoVADD_VV or
+  * PseudoVMUL_VV pseudos listed in the switch below, in either the unmasked
+  * or the _MASK form. Returns false for every other opcode, and also when
+  * Invert is set but the result of getInverseOpcode tests false.
+  *
+  * The two helper macros defined next each expand to a run of case labels,
+  * one per LMUL suffix (_M1, _M2, _M4, _M8, _MF2, _MF4, _MF8). The last
+  * label is emitted without its colon, so every use site supplies the
+  * trailing ':' itself. Both macros are #undef'd before the function ends.
+  */
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode; // Classify the inverse opcode instead of Inst's own.
+    else
+      return false; // No inverse opcode was provided for this instruction.
+  }
+
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false; // Opcode is not one of the listed VADD/VMUL pseudos.
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+                                             const MachineInstr &MI2) const {
+  if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+    return false;
+
+  assert(MI1.getMF() == MI2.getMF());
+  const MachineRegisterInfo *MRI = &MI1.getMF()->getRegInfo();
+  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
+
+  // Make sure vtype operands are also the same.
+  const MCInstrDesc &Desc = get(MI1.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+
+  auto checkImmOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
+  };
+
+  auto checkRegOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
+  };
+
+  // PassThru
+  if (!checkRegOperand(1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask
+  if (RISCVII::usesMaskPolicy(TSFlags)) {
+    const MachineBasicBlock *MBB = MI1.getParent();
+    const MachineBasicBlock::const_reverse_iterator It1(&MI1);
+    const MachineBasicBlock::const_reverse_iterator It2(&MI2);
+    Register MI1VReg;
+
+    bool SeenMI2 = false;
+    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
+      if (It == It2) {
+        SeenMI2 = true;
+        if (!MI1VReg.isValid())
+          // There is no V0 def between MI1 and MI2; they're sharing the
+          // same V0.
+          break;
+      }
+
+      if (It->definesRegister(RISCV::V0, TRI)) {
+        Register SrcReg =
+            TRI->lookThruCopyLike(It->getOperand(1).getReg(), MRI);
+
+        if (!MI1VReg.isValid()) {
+          // This is the V0 def for MI1.
+          MI1VReg = SrcReg;
+          continue;
+        }
+
+        // Some random mask updates.
+        if (!SeenMI2)
+          continue;
+
+        // This is the V0 def for MI2; check if it's the same as that of
+        // MI1.
+        if (MI1VReg != SrcReg)
----------------
mshockwave wrote:

I don't think SrcReg will necessarily be a virtual register. For example, if the mask is from a function argument, something like this will be generated:
```
    %1:vr = COPY $v0
    $v0 = COPY %1
    %3:vrnov0 = PseudoVSRL_VI_MF8_MASK $noreg, %0, 1, $v0, %2, 3 /* e8 */, 1 /* ta, mu */
    $v0 = COPY %1
    %5:vrnov0 = PseudoVAND_VX_MF8_MASK $noreg, killed %3, killed %4, $v0, %2, 3 /* e8 */, 1 /* ta, mu */
   ...
```
And SrcReg will be V0 in this case.

https://github.com/llvm/llvm-project/pull/88307


More information about the llvm-commits mailing list