[llvm] 004bf17 - [AArch64] Emit FNMADD instead of FNEG(FMADD)

Matt Devereau via llvm-commits llvm-commits at lists.llvm.org
Wed May 10 05:47:51 PDT 2023


Author: Matt Devereau
Date: 2023-05-10T12:45:54Z
New Revision: 004bf170c6cbaa049601bcf92f86a9459aec2dc2

URL: https://github.com/llvm/llvm-project/commit/004bf170c6cbaa049601bcf92f86a9459aec2dc2
DIFF: https://github.com/llvm/llvm-project/commit/004bf170c6cbaa049601bcf92f86a9459aec2dc2.diff

LOG: [AArch64] Emit FNMADD instead of FNEG(FMADD)

Emit FNMADD instead of FNEG(FMADD) at optimization levels above Oz,
when the fast-math flags (nsz and contract) permit it.
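
As a rough sketch (not taken verbatim from the patch; the added test
aarch64_fnmadd.ll below exercises the real cases), with placeholder
operands %x, %y, %z the combine takes IR such as

  ; placeholder operands, nsz + contract on every instruction
  %mul = fmul contract nsz double %x, %y
  %add = fadd contract nsz double %mul, %z
  %neg = fneg contract nsz double %add

and replaces the two-instruction sequence

  fmadd  d0, d0, d1, d2
  fneg   d0, d0

with a single

  fnmadd d0, d0, d1, d2

since AArch64 FNMADD computes -(Rn * Rm + Ra).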

Differential Revision: https://reviews.llvm.org/D149260

Added: 
    llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll

Modified: 
    llvm/include/llvm/CodeGen/MachineCombinerPattern.h
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 5be436b69a5b9..89eed7463bd78 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -178,6 +178,8 @@ enum class MachineCombinerPattern {
 
   // X86 VNNI
   DPWSSD,
+
+  FNMADD,
 };
 
 } // end namespace llvm

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 2bd0c1d782c2f..b34831d5807db 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -5409,6 +5409,39 @@ static bool getFMULPatterns(MachineInstr &Root,
   return Found;
 }
 
+static bool getFNEGPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  unsigned Opc = Root.getOpcode();
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+
+  auto Match = [&](unsigned Opcode, MachineCombinerPattern Pattern) -> bool {
+    MachineOperand &MO = Root.getOperand(1);
+    MachineInstr *MI = MRI.getUniqueVRegDef(MO.getReg());
+    if (MI != nullptr && MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()) &&
+        (MI->getOpcode() == Opcode) &&
+        Root.getFlag(MachineInstr::MIFlag::FmContract) &&
+        Root.getFlag(MachineInstr::MIFlag::FmNsz) &&
+        MI->getFlag(MachineInstr::MIFlag::FmContract) &&
+        MI->getFlag(MachineInstr::MIFlag::FmNsz)) {
+      Patterns.push_back(Pattern);
+      return true;
+    }
+    return false;
+  };
+
+  switch (Opc) {
+  default:
+    break;
+  case AArch64::FNEGDr:
+    return Match(AArch64::FMADDDrrr, MachineCombinerPattern::FNMADD);
+  case AArch64::FNEGSr:
+    return Match(AArch64::FMADDSrrr, MachineCombinerPattern::FNMADD);
+  }
+
+  return false;
+}
+
 /// Return true when a code sequence can improve throughput. It
 /// should be called only for instructions in loops.
 /// \param Pattern - combiner pattern
@@ -5578,6 +5611,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
     return true;
   if (getFMAPatterns(Root, Patterns))
     return true;
+  if (getFNEGPatterns(Root, Patterns))
+    return true;
 
   // Other patterns
   if (getMiscPatterns(Root, Patterns))
@@ -5668,6 +5703,47 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
   return MUL;
 }
 
+static MachineInstr *
+genFNegatedMAD(MachineFunction &MF, MachineRegisterInfo &MRI,
+               const TargetInstrInfo *TII, MachineInstr &Root,
+               SmallVectorImpl<MachineInstr *> &InsInstrs) {
+  MachineInstr *MAD = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+
+  unsigned Opc = 0;
+  const TargetRegisterClass *RC = MRI.getRegClass(MAD->getOperand(0).getReg());
+  if (AArch64::FPR32RegClass.hasSubClassEq(RC))
+    Opc = AArch64::FNMADDSrrr;
+  else if (AArch64::FPR64RegClass.hasSubClassEq(RC))
+    Opc = AArch64::FNMADDDrrr;
+  else
+    return nullptr;
+
+  Register ResultReg = Root.getOperand(0).getReg();
+  Register SrcReg0 = MAD->getOperand(1).getReg();
+  Register SrcReg1 = MAD->getOperand(2).getReg();
+  Register SrcReg2 = MAD->getOperand(3).getReg();
+  bool Src0IsKill = MAD->getOperand(1).isKill();
+  bool Src1IsKill = MAD->getOperand(2).isKill();
+  bool Src2IsKill = MAD->getOperand(3).isKill();
+  if (ResultReg.isVirtual())
+    MRI.constrainRegClass(ResultReg, RC);
+  if (SrcReg0.isVirtual())
+    MRI.constrainRegClass(SrcReg0, RC);
+  if (SrcReg1.isVirtual())
+    MRI.constrainRegClass(SrcReg1, RC);
+  if (SrcReg2.isVirtual())
+    MRI.constrainRegClass(SrcReg2, RC);
+
+  MachineInstrBuilder MIB =
+      BuildMI(MF, MIMetadata(Root), TII->get(Opc), ResultReg)
+          .addReg(SrcReg0, getKillRegState(Src0IsKill))
+          .addReg(SrcReg1, getKillRegState(Src1IsKill))
+          .addReg(SrcReg2, getKillRegState(Src2IsKill));
+  InsInstrs.push_back(MIB);
+
+  return MAD;
+}
+
 /// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
 static MachineInstr *
 genIndexedMultiply(MachineInstr &Root,
@@ -6800,6 +6876,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
                        &AArch64::FPR128_loRegClass, MRI);
     break;
   }
+  case MachineCombinerPattern::FNMADD: {
+    MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
+    break;
+  }
+
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
   if (MUL)

diff --git a/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll b/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll
new file mode 100644
index 0000000000000..b47736907e1e2
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -verify-machineinstrs | FileCheck %s
+
+define void @fnmaddd(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x1]
+; CHECK-NEXT:    ldr d1, [x0]
+; CHECK-NEXT:    ldr d2, [x2]
+; CHECK-NEXT:    fnmadd d0, d0, d1, d2
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul fast double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd fast double %mul, %2
+  %fneg = fneg fast double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+; Don't combine: No flags
+define void @fnmaddd_no_fast(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd_no_fast:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x0]
+; CHECK-NEXT:    ldr d1, [x1]
+; CHECK-NEXT:    fmul d0, d1, d0
+; CHECK-NEXT:    ldr d1, [x2]
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    fneg d0, d0
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd double %mul, %2
+  %fneg = fneg double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+define void @fnmadds(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul fast float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd fast float %mul, %2
+  %fneg = fneg fast float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+define void @fnmadds_nsz_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract nsz float %mul, %2
+  %fneg = fneg contract nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing nsz
+define void @fnmadds_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fmadd s0, s0, s1, s2
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract float %mul, %2
+  %fneg = fneg contract float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing contract
+define void @fnmadds_nsz(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x0]
+; CHECK-NEXT:    ldr s1, [x1]
+; CHECK-NEXT:    fmul s0, s1, s0
+; CHECK-NEXT:    ldr s1, [x2]
+; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd nsz float %mul, %2
+  %fneg = fneg nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+define void @fnmaddd_two_uses(ptr %a, ptr %b, ptr %c, ptr %d) {
+; CHECK-LABEL: fnmaddd_two_uses:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x1]
+; CHECK-NEXT:    ldr d1, [x0]
+; CHECK-NEXT:    ldr d2, [x2]
+; CHECK-NEXT:    fmadd d0, d0, d1, d2
+; CHECK-NEXT:    fneg d1, d0
+; CHECK-NEXT:    str d1, [x0]
+; CHECK-NEXT:    str d0, [x3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul fast double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd fast double %mul, %2
+  %fneg1 = fneg fast double %add
+  store double %fneg1, ptr %a, align 8
+  store double %add, ptr %d, align 8
+  ret void
+}

More information about the llvm-commits mailing list