[llvm] [AArch64][GlobalISel] Improve MULL generation (PR #112405)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 10 00:38:00 PDT 2025


================
@@ -438,6 +438,123 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Match a vector G_MUL whose operands are, or can be proven to behave like,
+// extended narrower values:
+//   mul({z/s}ext, {z/s}ext) => {u/s}mull
+// On success MatchInfo holds <use the unsigned form, first source register,
+// second source register> and the function returns true. MatchInfo is left
+// untouched when no pattern matches.
+bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       GISelKnownBits *KB,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  const Register LHS = MI.getOperand(1).getReg();
+  const Register RHS = MI.getOperand(2).getReg();
+  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+
+  // Look through copies to the instructions that produce each multiplicand.
+  MachineInstr *LHSDef = getDefIgnoringCopies(LHS, MRI);
+  MachineInstr *RHSDef = getDefIgnoringCopies(RHS, MRI);
+  const unsigned LHSOpc = LHSDef->getOpcode();
+  const unsigned RHSOpc = RHSDef->getOpcode();
+  const unsigned EltSize = DstTy.getScalarSizeInBits();
+
+  if (!DstTy.isVector())
+    return false;
+  if (LHSDef->getNumOperands() < 2 || RHSDef->getNumOperands() < 2)
+    return false;
+
+  // Operand 1 of an extend. Only call this on a definition that is already
+  // known to be one of the extend opcodes checked below.
+  auto ExtSrc = [](const MachineInstr *Ext) {
+    return Ext->getOperand(1).getReg();
+  };
+
+  // The destination elements must be at least twice as wide as the elements
+  // feeding the extend.
+  auto SrcIsNarrowEnough = [&](const MachineInstr *Ext) {
+    return EltSize >= MRI.getType(ExtSrc(Ext)).getScalarSizeInBits() * 2;
+  };
+
+  // Record a successful match.
+  auto Found = [&MatchInfo](bool IsZExt, Register Src1, Register Src2) {
+    MatchInfo = std::make_tuple(IsZExt, Src1, Src2);
+    return true;
+  };
+
+  // G_ANYEXT is accepted by both the unsigned and the signed patterns.
+  const bool LHSIsAnyExt = LHSOpc == TargetOpcode::G_ANYEXT;
+  const bool RHSIsAnyExt = RHSOpc == TargetOpcode::G_ANYEXT;
+  const bool LHSIsZExt = LHSIsAnyExt || LHSOpc == TargetOpcode::G_ZEXT;
+  const bool RHSIsZExt = RHSIsAnyExt || RHSOpc == TargetOpcode::G_ZEXT;
+  const bool LHSIsSExt = LHSIsAnyExt || LHSOpc == TargetOpcode::G_SEXT;
+  const bool RHSIsSExt = RHSIsAnyExt || RHSOpc == TargetOpcode::G_SEXT;
+
+  // Both operands are explicit extends of sufficiently narrow values: use the
+  // pre-extension registers directly.
+  auto TryBothExtended = [&](bool IsZExt, bool LHSExt, bool RHSExt) {
+    if (!LHSExt || !RHSExt)
+      return false;
+    if (!SrcIsNarrowEnough(LHSDef) || !SrcIsNarrowEnough(RHSDef))
+      return false;
+    return Found(IsZExt, ExtSrc(LHSDef), ExtSrc(RHSDef));
+  };
+
+  // The unsigned form is tried first, so two G_ANYEXT operands select it.
+  if (TryBothExtended(/*IsZExt=*/true, LHSIsZExt, RHSIsZExt) ||
+      TryBothExtended(/*IsZExt=*/false, LHSIsSExt, RHSIsSExt))
+    return true;
+
+  // Everything below relies on value tracking.
+  const APInt HighHalfMask = APInt::getHighBitsSet(EltSize, EltSize / 2);
+  if (!KB)
+    return false;
+
+  // Either one operand is an extend and the other provably fits in half the
+  // element width, or (for <2 x s64> results only) neither is an extend and
+  // both provably fit. When the first shape applies but the proof fails, the
+  // second shape is deliberately not attempted.
+  auto TryWithKnownBits = [&](bool IsZExt, bool LHSExt, bool RHSExt,
+                              auto &&FitsInHalf) {
+    if ((LHSExt || RHSExt) && SrcIsNarrowEnough(LHSExt ? LHSDef : RHSDef)) {
+      const Register Other = LHSExt ? RHS : LHS;
+      if (!FitsInHalf(Other))
+        return false;
+      return LHSExt ? Found(IsZExt, ExtSrc(LHSDef), Other)
+                    : Found(IsZExt, Other, ExtSrc(RHSDef));
+    }
+    if (DstTy == LLT::fixed_vector(2, 64) && FitsInHalf(LHS) &&
+        FitsInHalf(RHS))
+      return Found(IsZExt, LHS, RHS);
+    return false;
+  };
+
+  // Unsigned: the upper half of every element is known to be zero.
+  auto UpperHalfIsZero = [&](Register R) {
+    return KB->maskedValueIsZero(R, HighHalfMask);
+  };
+  // Signed: more than half of the bits are copies of the sign bit.
+  auto UpperHalfIsSign = [&](Register R) {
+    return KB->computeNumSignBits(R) > EltSize / 2;
+  };
+
+  return TryWithKnownBits(/*IsZExt=*/true, LHSIsZExt, RHSIsZExt,
+                          UpperHalfIsZero) ||
+         TryWithKnownBits(/*IsZExt=*/false, LHSIsSExt, RHSIsSExt,
+                          UpperHalfIsSign);
+}
+
+void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       MachineIRBuilder &B, GISelChangeObserver &Observer,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_MUL &&
+         "Expected a G_MUL instruction");
+
+  // Get the instructions that defined the source operand
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  bool IsZExt = get<0>(MatchInfo);
+  Register Src1Reg = get<1>(MatchInfo);
+  Register Src2Reg = get<2>(MatchInfo);
+  LLT Src1Ty = MRI.getType(Src1Reg);
+  LLT Src2Ty = MRI.getType(Src2Reg);
+  LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
+  unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
+
+  if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
+  if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);
+
+  B.setInstrAndDebugLoc(MI);
----------------
arsenm wrote:

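I think the builder's insertion point and debug location are already pre-set before any apply action runs now, so this `setInstrAndDebugLoc` call should no longer be needed.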

https://github.com/llvm/llvm-project/pull/112405


More information about the llvm-commits mailing list