[llvm] 5f67ce5 - [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (#88307)

via llvm-commits <llvm-commits at lists.llvm.org>
Thu Apr 25 16:36:14 PDT 2024


Author: Min-Yih Hsu
Date: 2024-04-25T16:36:11-07:00
New Revision: 5f67ce5611ba007ed363b6a78b9c4eac85b70837

URL: https://github.com/llvm/llvm-project/commit/5f67ce5611ba007ed363b6a78b9c4eac85b70837
DIFF: https://github.com/llvm/llvm-project/commit/5f67ce5611ba007ed363b6a78b9c4eac85b70837.diff

LOG: [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (#88307)

This patch adds basic reassociation optimizations for VADD_VV and VMUL_VV.
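
As an illustration (this mirrors the updated simple_vadd_vv test below; the
x/y/t names are just labels for this sketch), the MachineCombiner can now
rewrite a serial chain of dependent vector adds into a shallower tree:

    # Before: each vadd depends on the previous result
    # (critical path of three adds).
    vadd.vv v9, v8, v9    # t1 = x + y
    vadd.vv v9, v8, v9    # t2 = x + t1
    vadd.vv v8, v8, v9    # t3 = x + t2

    # After reassociation: the first two adds are independent
    # (critical path of two adds).
    vadd.vv v9, v8, v9    # t1 = x + y
    vadd.vv v8, v8, v8    # t2 = x + x
    vadd.vv v8, v8, v9    # t3 = t2 + t1

Both sequences compute t3 = 3*x + y; the reassociated form trades the
three-deep dependency chain for two adds that can execute in parallel.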

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.h
    llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5c1f154efa9911..3efd09aeae879d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1633,8 +1633,230 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode;
+    else
+      return false;
+  }
+
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false;
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
+                                             const MachineInstr &Prev) const {
+  if (!areOpcodesEqualOrInverse(Root.getOpcode(), Prev.getOpcode()))
+    return false;
+
+  assert(Root.getMF() == Prev.getMF());
+  const MachineRegisterInfo *MRI = &Root.getMF()->getRegInfo();
+  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
+
+  // Make sure vtype operands are also the same.
+  const MCInstrDesc &Desc = get(Root.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+
+  auto checkImmOperand = [&](unsigned OpIdx) {
+    return Root.getOperand(OpIdx).getImm() == Prev.getOperand(OpIdx).getImm();
+  };
+
+  auto checkRegOperand = [&](unsigned OpIdx) {
+    return Root.getOperand(OpIdx).getReg() == Prev.getOperand(OpIdx).getReg();
+  };
+
+  // PassThru
+  // TODO: Potentially we can loosen the condition to consider Root
+  // associable with Prev if Root has NoReg as its passthru, in which case
+  // we would also need to loosen the condition on the vector policies
+  // between the two.
+  if (!checkRegOperand(1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask
+  if (RISCVII::usesMaskPolicy(TSFlags)) {
+    const MachineBasicBlock *MBB = Root.getParent();
+    const MachineBasicBlock::const_reverse_iterator It1(&Root);
+    const MachineBasicBlock::const_reverse_iterator It2(&Prev);
+    Register MI1VReg;
+
+    bool SeenMI2 = false;
+    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
+      if (It == It2) {
+        SeenMI2 = true;
+        if (!MI1VReg.isValid())
+          // There is no V0 def between Root and Prev; they're sharing the
+          // same V0.
+          break;
+      }
+
+      if (It->modifiesRegister(RISCV::V0, TRI)) {
+        Register SrcReg = It->getOperand(1).getReg();
+        // If it's not a virtual register it'll be more difficult to track
+        // its defs, so bail out here just to be safe.
+        if (!SrcReg.isVirtual())
+          return false;
+
+        if (!MI1VReg.isValid()) {
+          // This is the V0 def for Root.
+          MI1VReg = SrcReg;
+          continue;
+        }
+
+        // An intervening V0 def unrelated to Root or Prev; keep scanning.
+        if (!SeenMI2)
+          continue;
+
+        // This is the V0 def for Prev; check if it's the same as that of
+        // Root.
+        if (MI1VReg != SrcReg)
+          return false;
+        else
+          break;
+      }
+    }
+
+    // If we haven't encountered Prev, this function was most likely called
+    // incorrectly (e.g. Root is before Prev).
+    assert(SeenMI2 && "Prev is expected to appear before Root");
+  }
+
+  // Tail / Mask policies
+  if (RISCVII::hasVecPolicyOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
+    return false;
+
+  // VL
+  if (RISCVII::hasVLOp(TSFlags)) {
+    unsigned OpIdx = RISCVII::getVLOpNum(Desc);
+    const MachineOperand &Op1 = Root.getOperand(OpIdx);
+    const MachineOperand &Op2 = Prev.getOperand(OpIdx);
+    if (Op1.getType() != Op2.getType())
+      return false;
+    switch (Op1.getType()) {
+    case MachineOperand::MO_Register:
+      if (Op1.getReg() != Op2.getReg())
+        return false;
+      break;
+    case MachineOperand::MO_Immediate:
+      if (Op1.getImm() != Op2.getImm())
+        return false;
+      break;
+    default:
+      llvm_unreachable("Unrecognized VL operand type");
+    }
+  }
+
+  // Rounding modes
+  if (RISCVII::hasRoundModeOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
+    return false;
+
+  return true;
+}
+
+// Most of our RVV pseudos have a passthru operand, so the real source
+// operands start at index 2.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+                                                  bool &Commuted) const {
+  const MachineBasicBlock *MBB = Inst.getParent();
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+  assert(RISCVII::isFirstDefTiedToFirstUse(get(Inst.getOpcode())) &&
+         "Expect the presence of a passthru operand.");
+  MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+  MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+  // If only one operand has the same or inverse opcode and it's the second
+  // source operand, the operands must be commuted.
+  Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
+             areRVVInstsReassociable(Inst, *MI2);
+  if (Commuted)
+    std::swap(MI1, MI2);
+
+  return areRVVInstsReassociable(Inst, *MI1) &&
+         (isVectorAssociativeAndCommutative(*MI1) ||
+          isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
+         hasReassociableOperands(*MI1, MBB) &&
+         MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+    const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+  if (!isVectorAssociativeAndCommutative(Inst) &&
+      !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+
+  const MachineOperand &Op1 = Inst.getOperand(2);
+  const MachineOperand &Op2 = Inst.getOperand(3);
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+  // We need virtual register definitions for the operands that we will
+  // reassociate.
+  MachineInstr *MI1 = nullptr;
+  MachineInstr *MI2 = nullptr;
+  if (Op1.isReg() && Op1.getReg().isVirtual())
+    MI1 = MRI.getUniqueVRegDef(Op1.getReg());
+  if (Op2.isReg() && Op2.getReg().isVirtual())
+    MI2 = MRI.getUniqueVRegDef(Op2.getReg());
+
+  // And at least one operand must be defined in MBB.
+  return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
+}
+
+void RISCVInstrInfo::getReassociateOperandIndices(
+    const MachineInstr &Root, unsigned Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
+  if (RISCV::getRVVMCOpcode(Root.getOpcode())) {
+    // Skip the passthru operand by incrementing all indices by one.
+    for (unsigned I = 0; I < 5; ++I)
+      ++OperandIndices[I];
+  }
+}
+
 bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
                                             bool &Commuted) const {
+  if (isVectorAssociativeAndCommutative(Inst) ||
+      isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return hasReassociableVectorSibling(Inst, Commuted);
+
   if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
     return false;
 
@@ -1654,6 +1876,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
 
 bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                  bool Invert) const {
+  if (isVectorAssociativeAndCommutative(Inst, Invert))
+    return true;
+
   unsigned Opc = Inst.getOpcode();
   if (Invert) {
     auto InverseOpcode = getInverseOpcode(Opc);
@@ -1706,6 +1931,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
 
 std::optional<unsigned>
 RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+#define RVV_OPC_LMUL_CASE(OPC, INV)                                            \
+  case RISCV::OPC##_M1:                                                        \
+    return RISCV::INV##_M1;                                                    \
+  case RISCV::OPC##_M2:                                                        \
+    return RISCV::INV##_M2;                                                    \
+  case RISCV::OPC##_M4:                                                        \
+    return RISCV::INV##_M4;                                                    \
+  case RISCV::OPC##_M8:                                                        \
+    return RISCV::INV##_M8;                                                    \
+  case RISCV::OPC##_MF2:                                                       \
+    return RISCV::INV##_MF2;                                                   \
+  case RISCV::OPC##_MF4:                                                       \
+    return RISCV::INV##_MF4;                                                   \
+  case RISCV::OPC##_MF8:                                                       \
+    return RISCV::INV##_MF8
+
+#define RVV_OPC_LMUL_MASK_CASE(OPC, INV)                                       \
+  case RISCV::OPC##_M1_MASK:                                                   \
+    return RISCV::INV##_M1_MASK;                                               \
+  case RISCV::OPC##_M2_MASK:                                                   \
+    return RISCV::INV##_M2_MASK;                                               \
+  case RISCV::OPC##_M4_MASK:                                                   \
+    return RISCV::INV##_M4_MASK;                                               \
+  case RISCV::OPC##_M8_MASK:                                                   \
+    return RISCV::INV##_M8_MASK;                                               \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+    return RISCV::INV##_MF2_MASK;                                              \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+    return RISCV::INV##_MF4_MASK;                                              \
+  case RISCV::OPC##_MF8_MASK:                                                  \
+    return RISCV::INV##_MF8_MASK
+
   switch (Opcode) {
   default:
     return std::nullopt;
@@ -1729,7 +1986,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
     return RISCV::SUBW;
   case RISCV::SUBW:
     return RISCV::ADDW;
+    // clang-format off
+  RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+    // clang-format on
   }
+
+#undef RVV_OPC_LMUL_MASK_CASE
+#undef RVV_OPC_LMUL_CASE
 }
 
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 3b03d5efde6ef5..170f813eb10d7d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -266,6 +266,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
       SmallVectorImpl<MachineInstr *> &DelInstrs,
       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
 
+  bool hasReassociableOperands(const MachineInstr &Inst,
+                               const MachineBasicBlock *MBB) const override;
+
   bool hasReassociableSibling(const MachineInstr &Inst,
                               bool &Commuted) const override;
 
@@ -274,6 +277,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
+  void getReassociateOperandIndices(
+      const MachineInstr &Root, unsigned Pattern,
+      std::array<unsigned, 5> &OperandIndices) const override;
+
   ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
   getSerializableMachineMemOperandTargetFlags() const override;
 
@@ -297,6 +304,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
 private:
   unsigned getInstBundleLength(const MachineInstr &MI) const;
+
+  bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
+                                         bool Invert = false) const;
+  bool areRVVInstsReassociable(const MachineInstr &MI1,
+                               const MachineInstr &MI2) const;
+  bool hasReassociableVectorSibling(const MachineInstr &Inst,
+                                    bool &Commuted) const;
 };
 
 namespace RISCV {

diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
index 3cb6f3c35286cf..6435c1c14e061e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -31,7 +31,7 @@ define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vadd.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -61,7 +61,7 @@ define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vsub.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -91,7 +91,7 @@ define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmul.vv v9, v8, v9
-; CHECK-NEXT:    vmul.vv v9, v8, v9
+; CHECK-NEXT:    vmul.vv v8, v8, v8
 ; CHECK-NEXT:    vmul.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -124,8 +124,8 @@ define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v8
+; CHECK-NEXT:    vadd.vv v8, v9, v10
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
@@ -187,8 +187,8 @@ define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vadd.vv v9, v8, v8, v0.t
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -215,15 +215,16 @@ entry:
   ret <vscale x 1 x i8> %c
 }
 
-define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m, <vscale x 1 x i1> %m2) nounwind {
 ; CHECK-LABEL: vadd_vv_mask_negative:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, mu
-; CHECK-NEXT:    vmv1r.v v10, v8
-; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v11, v8
+; CHECK-NEXT:    vadd.vv v11, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v11, v0.t
+; CHECK-NEXT:    vmv1r.v v0, v10
+; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -240,8 +241,6 @@ entry:
     <vscale x 1 x i1> %m,
     i32 %2, i32 1)
 
-  %splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
-  %m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
   %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
     <vscale x 1 x i8> %0,
     <vscale x 1 x i8> %0,
