[llvm] [AArch64][GlobalISel] Improve MULL generation (PR #112405)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 10 00:38:00 PDT 2025
================
@@ -438,6 +438,123 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}
+// Match mul({z/s}ext, {z/s}ext) => {u/s}mull.
+//
+// Recognizes a vector G_MUL that can be computed as a widening multiply of
+// two half-width inputs:
+//  * both operands are zero/any-extends (UMULL) or sign/any-extends (SMULL)
+//    whose inputs are at most half the result element width;
+//  * with known bits, one such extend paired with an operand whose upper
+//    half is known zero (UMULL) or that has more than half its bits as
+//    known sign bits (SMULL);
+//  * with known bits and a <2 x s64> result, two plain operands that both
+//    satisfy the same known-bits condition.
+// G_ANYEXT counts as both kinds of extend, since its high bits are undefined.
+//
+// On success MatchInfo = {IsUnsigned, LHS, RHS}: true selects UMULL, false
+// SMULL. LHS/RHS are not necessarily half-width yet: they may be extend
+// sources of any narrower width, or the full-width G_MUL operands.
+bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       GISelKnownBits *KB,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  Register LHS = MI.getOperand(1).getReg();
+  Register RHS = MI.getOperand(2).getReg();
+  MachineInstr *LHSDef = getDefIgnoringCopies(LHS, MRI);
+  MachineInstr *RHSDef = getDefIgnoringCopies(RHS, MRI);
+  unsigned LHSOpc = LHSDef->getOpcode();
+  unsigned RHSOpc = RHSDef->getOpcode();
+  unsigned EltSize = DstTy.getScalarSizeInBits();
+
+  if (!DstTy.isVector())
+    return false;
+  if (LHSDef->getNumOperands() < 2 || RHSDef->getNumOperands() < 2)
+    return false;
+
+  // Store the selected form and its operands, then report success.
+  auto Accept = [&MatchInfo](bool IsUnsigned, Register A, Register B) {
+    std::get<0>(MatchInfo) = IsUnsigned;
+    std::get<1>(MatchInfo) = A;
+    std::get<2>(MatchInfo) = B;
+    return true;
+  };
+  // Only call on G_ZEXT/G_SEXT/G_ANYEXT, whose operand 1 is the source
+  // register; other defining instructions may not have a register there.
+  auto ExtSrc = [](const MachineInstr *Ext) {
+    return Ext->getOperand(1).getReg();
+  };
+  // The extend's input must be at most half as wide as a result element.
+  auto IsAtLeastDoubleExtend = [&](Register R) {
+    return EltSize >= 2 * MRI.getType(R).getScalarSizeInBits();
+  };
+  auto IsZeroLikeExt = [](unsigned Opc) {
+    return Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT;
+  };
+  auto IsSignLikeExt = [](unsigned Opc) {
+    return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_ANYEXT;
+  };
+
+  const bool ZExtLHS = IsZeroLikeExt(LHSOpc);
+  const bool ZExtRHS = IsZeroLikeExt(RHSOpc);
+  const bool SExtLHS = IsSignLikeExt(LHSOpc);
+  const bool SExtRHS = IsSignLikeExt(RHSOpc);
+
+  // Both operands are extends of narrow-enough values: multiply the extend
+  // sources directly, preferring UMULL when both extends permit it.
+  const bool BothExtends = (ZExtLHS || SExtLHS) && (ZExtRHS || SExtRHS);
+  if (BothExtends && IsAtLeastDoubleExtend(ExtSrc(LHSDef)) &&
+      IsAtLeastDoubleExtend(ExtSrc(RHSDef))) {
+    if (ZExtLHS && ZExtRHS)
+      return Accept(true, ExtSrc(LHSDef), ExtSrc(RHSDef));
+    if (SExtLHS && SExtRHS)
+      return Accept(false, ExtSrc(LHSDef), ExtSrc(RHSDef));
+  }
+
+  // Every remaining pattern needs known-bits information.
+  if (!KB)
+    return false;
+
+  const bool IsV2S64 = DstTy == LLT::fixed_vector(2, 64);
+  const unsigned HalfBits = EltSize / 2;
+
+  // UMULL: a wide-enough zero/any-extend on one side, and an operand on the
+  // other side whose upper half is known to be zero.
+  APInt HighHalf = APInt::getHighBitsSet(EltSize, HalfBits);
+  auto UpperHalfIsZero = [&](Register R) {
+    return KB->maskedValueIsZero(R, HighHalf);
+  };
+  if ((ZExtLHS || ZExtRHS) &&
+      IsAtLeastDoubleExtend(ZExtLHS ? ExtSrc(LHSDef) : ExtSrc(RHSDef))) {
+    Register Other = ZExtLHS ? RHS : LHS;
+    if (UpperHalfIsZero(Other))
+      return ZExtLHS ? Accept(true, ExtSrc(LHSDef), Other)
+                     : Accept(true, Other, ExtSrc(RHSDef));
+  } else if (IsV2S64 && UpperHalfIsZero(LHS) && UpperHalfIsZero(RHS)) {
+    return Accept(true, LHS, RHS);
+  }
+
+  // SMULL: the same shape, using sign/any-extends and operands with more
+  // than HalfBits known sign bits.
+  auto FitsSignedHalf = [&](Register R) {
+    return KB->computeNumSignBits(R) > HalfBits;
+  };
+  if ((SExtLHS || SExtRHS) &&
+      IsAtLeastDoubleExtend(SExtLHS ? ExtSrc(LHSDef) : ExtSrc(RHSDef))) {
+    Register Other = SExtLHS ? RHS : LHS;
+    if (FitsSignedHalf(Other))
+      return SExtLHS ? Accept(false, ExtSrc(LHSDef), Other)
+                     : Accept(false, Other, ExtSrc(RHSDef));
+  } else if (IsV2S64 && FitsSignedHalf(LHS) && FitsSignedHalf(RHS)) {
+    return Accept(false, LHS, RHS);
+  }
+
+  return false;
+}
+
+void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &B, GISelChangeObserver &Observer,
+ std::tuple<bool, Register, Register> &MatchInfo) {
+ assert(MI.getOpcode() == TargetOpcode::G_MUL &&
+ "Expected a G_MUL instruction");
+
+ // Get the instructions that defined the source operand
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+ bool IsZExt = get<0>(MatchInfo);
+ Register Src1Reg = get<1>(MatchInfo);
+ Register Src2Reg = get<2>(MatchInfo);
+ LLT Src1Ty = MRI.getType(Src1Reg);
+ LLT Src2Ty = MRI.getType(Src2Reg);
+ LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
+ unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
+
+ if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+ Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
+ if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+ Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);
+
+ B.setInstrAndDebugLoc(MI);
----------------
arsenm wrote:
I think this is pre-set before any apply action now
https://github.com/llvm/llvm-project/pull/112405
More information about the llvm-commits
mailing list