[llvm] [RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions (PR #88307)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 10 13:44:28 PDT 2024
================
@@ -1619,8 +1619,184 @@ static bool isFMUL(unsigned Opc) {
}
}
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+ bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC) \
+ case RISCV::OPC##_M1: \
+ case RISCV::OPC##_M2: \
+ case RISCV::OPC##_M4: \
+ case RISCV::OPC##_M8: \
+ case RISCV::OPC##_MF2: \
+ case RISCV::OPC##_MF4: \
+ case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC) \
+ case RISCV::OPC##_M1_MASK: \
+ case RISCV::OPC##_M2_MASK: \
+ case RISCV::OPC##_M4_MASK: \
+ case RISCV::OPC##_M8_MASK: \
+ case RISCV::OPC##_MF2_MASK: \
+ case RISCV::OPC##_MF4_MASK: \
+ case RISCV::OPC##_MF8_MASK
+
+ unsigned Opcode = Inst.getOpcode();
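+ // When asked about the inverse operation, translate to the inverse opcode
+ // first (and bail out if there is none).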
+ if (Invert) {
+ if (auto InvOpcode = getInverseOpcode(Opcode))
+ Opcode = *InvOpcode;
+ else
+ return false;
+ }
+
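+ // For now only the integer vector add and multiply families (VADD, VMUL,
+ // VMULH, VMULHU), masked and unmasked, across all LMULs are treated as
+ // associative and commutative.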
+ // clang-format off
+ switch (Opcode) {
+ default:
+ return false;
+ OPCODE_LMUL_CASE(PseudoVADD_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+ OPCODE_LMUL_CASE(PseudoVMUL_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+ OPCODE_LMUL_CASE(PseudoVMULH_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMULH_VV):
+ OPCODE_LMUL_CASE(PseudoVMULHU_VV):
+ OPCODE_LMUL_MASK_CASE(PseudoVMULHU_VV):
+ return true;
+ }
+ // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+ const MachineInstr &MI2) const {
+ if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+ return false;
+
+ // Make sure vtype operands are also the same.
+ const MCInstrDesc &Desc = get(MI1.getOpcode());
+ const uint64_t TSFlags = Desc.TSFlags;
+
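+ // Helpers that compare the operand at the same index in both instructions.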
+ auto checkImmOperand = [&](unsigned OpIdx) {
+ return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
+ };
+
+ auto checkRegOperand = [&](unsigned OpIdx) {
+ return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
+ };
+
+ // PassThru
+ if (!checkRegOperand(1))
+ return false;
+
+ // SEW
+ if (RISCVII::hasSEWOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+ return false;
+
+ // Mask
+ // There might be more sophisticated ways to check equality of masks, but
+ // right now we simply check if they're the same virtual register.
+ if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
+ return false;
+
+ // Tail / Mask policies
+ if (RISCVII::hasVecPolicyOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
+ return false;
+
+ // VL
+ if (RISCVII::hasVLOp(TSFlags)) {
+ unsigned OpIdx = RISCVII::getVLOpNum(Desc);
+ const MachineOperand &Op1 = MI1.getOperand(OpIdx);
+ const MachineOperand &Op2 = MI2.getOperand(OpIdx);
+ if (Op1.getType() != Op2.getType())
+ return false;
+ switch (Op1.getType()) {
+ case MachineOperand::MO_Register:
+ if (Op1.getReg() != Op2.getReg())
+ return false;
+ break;
+ case MachineOperand::MO_Immediate:
+ if (Op1.getImm() != Op2.getImm())
+ return false;
+ break;
+ default:
+ llvm_unreachable("Unrecognized VL operand type");
+ }
+ }
+
+ // Rounding modes
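+ // The rounding-mode operand, when present, sits immediately before VL.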
+ if (RISCVII::hasRoundModeOp(TSFlags) &&
+ !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
+ return false;
+
+ return true;
+}
+
+// Most of our RVV pseudos have a passthru operand, so the real operands
+// start from index 2.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+ bool &Commuted) const {
+ const MachineBasicBlock *MBB = Inst.getParent();
+ const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+ MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+ MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+ // If only one operand has the same or inverse opcode and it's the second
+ // source operand, the operands must be commuted.
+ Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
+ areRVVInstsReassociable(Inst, *MI2);
+ if (Commuted)
+ std::swap(MI1, MI2);
+
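+ // MI1 must be reassociable with Inst (same or inverse opcode and matching
+ // VTYPE operands), itself be associative and commutative (possibly as the
+ // inverse operation), have reassociable operands, and its result must have
+ // exactly one non-debug use.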
+ return areRVVInstsReassociable(Inst, *MI1) &&
+ (isVectorAssociativeAndCommutative(*MI1) ||
+ isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
+ hasReassociableOperands(*MI1, MBB) &&
+ MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+ const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+ if (!isVectorAssociativeAndCommutative(Inst) &&
+ !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+ return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+
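+ // For RVV pseudos the first real source operand is at index 2, since
+ // index 1 is the passthru.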
+ const MachineOperand &Op1 = Inst.getOperand(2);
+ const MachineOperand &Op2 = Inst.getOperand(3);
+ const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+ // We need virtual register definitions for the operands that we will
+ // reassociate.
+ MachineInstr *MI1 = nullptr;
+ MachineInstr *MI2 = nullptr;
+ if (Op1.isReg() && Op1.getReg().isVirtual())
+ MI1 = MRI.getUniqueVRegDef(Op1.getReg());
+ if (Op2.isReg() && Op2.getReg().isVirtual())
+ MI2 = MRI.getUniqueVRegDef(Op2.getReg());
+
+ // And at least one operand must be defined in MBB.
+ return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
+}
+
+void RISCVInstrInfo::getReassociateOperandIdx(
+ const MachineInstr &Root, MachineCombinerPattern Pattern,
+ std::array<unsigned, 5> &OperandIndices) const {
+ TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
+ if (isVectorAssociativeAndCommutative(Root) ||
+ isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
+ // Skip the passthru operand, so increment all indices by one.
----------------
mshockwave wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/88307