[llvm] r371321 - [aarch64] Add combine patterns for fp16 fmla

Sebastian Pop via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 7 13:24:51 PDT 2019


Author: spop
Date: Sat Sep  7 13:24:51 2019
New Revision: 371321

URL: http://llvm.org/viewvc/llvm-project?rev=371321&view=rev
Log:
[aarch64] Add combine patterns for fp16 fmla

This patch enables the machine combiner to generate fused multiply-add/subtract instructions for fp16 (scalar and vector) operands.
Tested on aarch64-linux.

Differential Revision: https://reviews.llvm.org/D67297

Added:
    llvm/trunk/test/CodeGen/AArch64/fp16-fmla.ll
Modified:
    llvm/trunk/include/llvm/CodeGen/MachineCombinerPattern.h
    llvm/trunk/lib/Target/AArch64/AArch64InstrInfo.cpp

Modified: llvm/trunk/include/llvm/CodeGen/MachineCombinerPattern.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/MachineCombinerPattern.h?rev=371321&r1=371320&r2=371321&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/MachineCombinerPattern.h (original)
+++ llvm/trunk/include/llvm/CodeGen/MachineCombinerPattern.h Sat Sep  7 13:24:51 2019
@@ -39,6 +39,10 @@ enum class MachineCombinerPattern {
   MULADDXI_OP1,
   MULSUBXI_OP1,
   // Floating Point
+  FMULADDH_OP1,
+  FMULADDH_OP2,
+  FMULSUBH_OP1,
+  FMULSUBH_OP2,
   FMULADDS_OP1,
   FMULADDS_OP2,
   FMULSUBS_OP1,
@@ -47,16 +51,25 @@ enum class MachineCombinerPattern {
   FMULADDD_OP2,
   FMULSUBD_OP1,
   FMULSUBD_OP2,
+  FNMULSUBH_OP1,
   FNMULSUBS_OP1,
   FNMULSUBD_OP1,
   FMLAv1i32_indexed_OP1,
   FMLAv1i32_indexed_OP2,
   FMLAv1i64_indexed_OP1,
   FMLAv1i64_indexed_OP2,
+  FMLAv4f16_OP1,
+  FMLAv4f16_OP2,
+  FMLAv8f16_OP1,
+  FMLAv8f16_OP2,
   FMLAv2f32_OP2,
   FMLAv2f32_OP1,
   FMLAv2f64_OP1,
   FMLAv2f64_OP2,
+  FMLAv4i16_indexed_OP1,
+  FMLAv4i16_indexed_OP2,
+  FMLAv8i16_indexed_OP1,
+  FMLAv8i16_indexed_OP2,
   FMLAv2i32_indexed_OP1,
   FMLAv2i32_indexed_OP2,
   FMLAv2i64_indexed_OP1,
@@ -67,10 +80,16 @@ enum class MachineCombinerPattern {
   FMLAv4i32_indexed_OP2,
   FMLSv1i32_indexed_OP2,
   FMLSv1i64_indexed_OP2,
+  FMLSv4f16_OP2,
+  FMLSv8f16_OP1,
+  FMLSv8f16_OP2,
   FMLSv2f32_OP1,
   FMLSv2f32_OP2,
   FMLSv2f64_OP1,
   FMLSv2f64_OP2,
+  FMLSv4i16_indexed_OP2,
+  FMLSv8i16_indexed_OP1,
+  FMLSv8i16_indexed_OP2,
   FMLSv2i32_indexed_OP1,
   FMLSv2i32_indexed_OP2,
   FMLSv2i64_indexed_OP1,

Modified: llvm/trunk/lib/Target/AArch64/AArch64InstrInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/AArch64/AArch64InstrInfo.cpp?rev=371321&r1=371320&r2=371321&view=diff
==============================================================================
--- llvm/trunk/lib/Target/AArch64/AArch64InstrInfo.cpp (original)
+++ llvm/trunk/lib/Target/AArch64/AArch64InstrInfo.cpp Sat Sep  7 13:24:51 2019
@@ -3466,13 +3466,19 @@ static bool isCombineInstrCandidateFP(co
   switch (Inst.getOpcode()) {
   default:
     break;
+  case AArch64::FADDHrr:
   case AArch64::FADDSrr:
   case AArch64::FADDDrr:
+  case AArch64::FADDv4f16:
+  case AArch64::FADDv8f16:
   case AArch64::FADDv2f32:
   case AArch64::FADDv2f64:
   case AArch64::FADDv4f32:
+  case AArch64::FSUBHrr:
   case AArch64::FSUBSrr:
   case AArch64::FSUBDrr:
+  case AArch64::FSUBv4f16:
+  case AArch64::FSUBv8f16:
   case AArch64::FSUBv2f32:
   case AArch64::FSUBv2f64:
   case AArch64::FSUBv4f32:
@@ -3682,9 +3688,21 @@ static bool getFMAPatterns(MachineInstr
   default:
     assert(false && "Unsupported FP instruction in combiner\n");
     break;
+  case AArch64::FADDHrr:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() &&
+           "FADDHrr does not have register operands");
+    if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULHrr)) {
+      Patterns.push_back(MachineCombinerPattern::FMULADDH_OP1);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULHrr)) {
+      Patterns.push_back(MachineCombinerPattern::FMULADDH_OP2);
+      Found = true;
+    }
+    break;
   case AArch64::FADDSrr:
     assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() &&
-           "FADDWrr does not have register operands");
+           "FADDSrr does not have register operands");
     if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) {
       Patterns.push_back(MachineCombinerPattern::FMULADDS_OP1);
       Found = true;
@@ -3720,6 +3738,46 @@ static bool getFMAPatterns(MachineInstr
       Found = true;
     }
     break;
+  case AArch64::FADDv4f16:
+    if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                           AArch64::FMULv4i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv4i16_indexed_OP1);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                                  AArch64::FMULv4f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv4f16_OP1);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                           AArch64::FMULv4i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv4i16_indexed_OP2);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                                  AArch64::FMULv4f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv4f16_OP2);
+      Found = true;
+    }
+    break;
+  case AArch64::FADDv8f16:
+    if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                           AArch64::FMULv8i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv8i16_indexed_OP1);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                                  AArch64::FMULv8f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv8f16_OP1);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                           AArch64::FMULv8i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv8i16_indexed_OP2);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                                  AArch64::FMULv8f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLAv8f16_OP2);
+      Found = true;
+    }
+    break;
   case AArch64::FADDv2f32:
     if (canCombineWithFMUL(MBB, Root.getOperand(1),
                            AArch64::FMULv2i32_indexed)) {
@@ -3781,6 +3839,20 @@ static bool getFMAPatterns(MachineInstr
     }
     break;
 
+  case AArch64::FSUBHrr:
+    if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULHrr)) {
+      Patterns.push_back(MachineCombinerPattern::FMULSUBH_OP1);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULHrr)) {
+      Patterns.push_back(MachineCombinerPattern::FMULSUBH_OP2);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FNMULHrr)) {
+      Patterns.push_back(MachineCombinerPattern::FNMULSUBH_OP1);
+      Found = true;
+    }
+    break;
   case AArch64::FSUBSrr:
     if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) {
       Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP1);
@@ -3817,6 +3889,46 @@ static bool getFMAPatterns(MachineInstr
       Found = true;
     }
     break;
+  case AArch64::FSUBv4f16:
+    // Only the subtrahend (operand 2) is fused: "acc - a*b" maps onto FMLS.
+    // Prefer the by-element (indexed) multiply, else the plain vector one.
+    if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                           AArch64::FMULv4i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv4i16_indexed_OP2);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                                  AArch64::FMULv4f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv4f16_OP2);
+      Found = true;
+    }
+    // Operand 1 is deliberately not matched. There is no FMLSv4f16_OP1 or
+    // FMLSv4i16_indexed_OP1 pattern, and the FMLSv2i32_indexed_OP1 /
+    // FMLSv2f32_OP1 patterns that used to be pushed here belong to the
+    // FSUBv2f32 case: selecting them for a 4 x f16 FSUB would hand f16
+    // lanes to an expansion written for f32 lanes. Leave "a*b - acc"
+    // unfused until dedicated fp16 OP1 patterns (and their expansion in
+    // genAlternativeCodeSequence) are added.
+    break;
+  case AArch64::FSUBv8f16:
+    if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                           AArch64::FMULv8i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv8i16_indexed_OP2);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(2),
+                                  AArch64::FMULv8f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv8f16_OP2);
+      Found = true;
+    }
+    if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                           AArch64::FMULv8i16_indexed)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv8i16_indexed_OP1);
+      Found = true;
+    } else if (canCombineWithFMUL(MBB, Root.getOperand(1),
+                                  AArch64::FMULv8f16)) {
+      Patterns.push_back(MachineCombinerPattern::FMLSv8f16_OP1);
+      Found = true;
+    }
+    break;
   case AArch64::FSUBv2f32:
     if (canCombineWithFMUL(MBB, Root.getOperand(2),
                            AArch64::FMULv2i32_indexed)) {
@@ -3889,6 +4001,10 @@ bool AArch64InstrInfo::isThroughputPatte
   switch (Pattern) {
   default:
     break;
+  case MachineCombinerPattern::FMULADDH_OP1:
+  case MachineCombinerPattern::FMULADDH_OP2:
+  case MachineCombinerPattern::FMULSUBH_OP1:
+  case MachineCombinerPattern::FMULSUBH_OP2:
   case MachineCombinerPattern::FMULADDS_OP1:
   case MachineCombinerPattern::FMULADDS_OP2:
   case MachineCombinerPattern::FMULSUBS_OP1:
@@ -3897,12 +4013,21 @@ bool AArch64InstrInfo::isThroughputPatte
   case MachineCombinerPattern::FMULADDD_OP2:
   case MachineCombinerPattern::FMULSUBD_OP1:
   case MachineCombinerPattern::FMULSUBD_OP2:
+  case MachineCombinerPattern::FNMULSUBH_OP1:
   case MachineCombinerPattern::FNMULSUBS_OP1:
   case MachineCombinerPattern::FNMULSUBD_OP1:
+  case MachineCombinerPattern::FMLAv4i16_indexed_OP1:
+  case MachineCombinerPattern::FMLAv4i16_indexed_OP2:
+  case MachineCombinerPattern::FMLAv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMLAv8i16_indexed_OP2:
   case MachineCombinerPattern::FMLAv1i32_indexed_OP1:
   case MachineCombinerPattern::FMLAv1i32_indexed_OP2:
   case MachineCombinerPattern::FMLAv1i64_indexed_OP1:
   case MachineCombinerPattern::FMLAv1i64_indexed_OP2:
+  case MachineCombinerPattern::FMLAv4f16_OP2:
+  case MachineCombinerPattern::FMLAv4f16_OP1:
+  case MachineCombinerPattern::FMLAv8f16_OP1:
+  case MachineCombinerPattern::FMLAv8f16_OP2:
   case MachineCombinerPattern::FMLAv2f32_OP2:
   case MachineCombinerPattern::FMLAv2f32_OP1:
   case MachineCombinerPattern::FMLAv2f64_OP1:
@@ -3915,10 +4040,16 @@ bool AArch64InstrInfo::isThroughputPatte
   case MachineCombinerPattern::FMLAv4f32_OP2:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP1:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP2:
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP2:
   case MachineCombinerPattern::FMLSv1i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv1i64_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i64_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4f16_OP2:
+  case MachineCombinerPattern::FMLSv8f16_OP1:
+  case MachineCombinerPattern::FMLSv8f16_OP2:
   case MachineCombinerPattern::FMLSv2f32_OP2:
   case MachineCombinerPattern::FMLSv2f64_OP2:
   case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
@@ -4266,34 +4397,35 @@ void AArch64InstrInfo::genAlternativeCod
     break;
   }
   // Floating Point Support
+  case MachineCombinerPattern::FMULADDH_OP1:
+    Opc = AArch64::FMADDHrrr;
+    RC = &AArch64::FPR16RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULADDS_OP1:
+    Opc = AArch64::FMADDSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULADDD_OP1:
-    // MUL I=A,B,0
-    // ADD R,I,C
-    // ==> MADD R,A,B,C
-    // --- Create(MADD);
-    if (Pattern == MachineCombinerPattern::FMULADDS_OP1) {
-      Opc = AArch64::FMADDSrrr;
-      RC = &AArch64::FPR32RegClass;
-    } else {
-      Opc = AArch64::FMADDDrrr;
-      RC = &AArch64::FPR64RegClass;
-    }
+    Opc = AArch64::FMADDDrrr;
+    RC = &AArch64::FPR64RegClass;
     MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
     break;
+
+  case MachineCombinerPattern::FMULADDH_OP2:
+    Opc = AArch64::FMADDHrrr;
+    RC = &AArch64::FPR16RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULADDS_OP2:
+    Opc = AArch64::FMADDSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULADDD_OP2:
-    // FMUL I=A,B,0
-    // FADD R,C,I
-    // ==> FMADD R,A,B,C
-    // --- Create(FMADD);
-    if (Pattern == MachineCombinerPattern::FMULADDS_OP2) {
-      Opc = AArch64::FMADDSrrr;
-      RC = &AArch64::FPR32RegClass;
-    } else {
-      Opc = AArch64::FMADDDrrr;
-      RC = &AArch64::FPR64RegClass;
-    }
+    Opc = AArch64::FMADDDrrr;
+    RC = &AArch64::FPR64RegClass;
     MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
 
@@ -4323,6 +4455,31 @@ void AArch64InstrInfo::genAlternativeCod
                            FMAInstKind::Indexed);
     break;
 
+  case MachineCombinerPattern::FMLAv4i16_indexed_OP1:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLAv4i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+  case MachineCombinerPattern::FMLAv4f16_OP1:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLAv4f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+  case MachineCombinerPattern::FMLAv4i16_indexed_OP2:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLAv4i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+  case MachineCombinerPattern::FMLAv4f16_OP2:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLAv4f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+
   case MachineCombinerPattern::FMLAv2i32_indexed_OP1:
   case MachineCombinerPattern::FMLAv2f32_OP1:
     RC = &AArch64::FPR64RegClass;
@@ -4350,6 +4507,31 @@ void AArch64InstrInfo::genAlternativeCod
     }
     break;
 
+  case MachineCombinerPattern::FMLAv8i16_indexed_OP1:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLAv8i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+  case MachineCombinerPattern::FMLAv8f16_OP1:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLAv8f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+  case MachineCombinerPattern::FMLAv8i16_indexed_OP2:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLAv8i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+  case MachineCombinerPattern::FMLAv8f16_OP2:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLAv8f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+
   case MachineCombinerPattern::FMLAv2i64_indexed_OP1:
   case MachineCombinerPattern::FMLAv2f64_OP1:
     RC = &AArch64::FPR128RegClass;
@@ -4405,56 +4587,53 @@ void AArch64InstrInfo::genAlternativeCod
     }
     break;
 
+  case MachineCombinerPattern::FMULSUBH_OP1:
+    Opc = AArch64::FNMSUBHrrr;
+    RC = &AArch64::FPR16RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULSUBS_OP1:
-  case MachineCombinerPattern::FMULSUBD_OP1: {
-    // FMUL I=A,B,0
-    // FSUB R,I,C
-    // ==> FNMSUB R,A,B,C // = -C + A*B
-    // --- Create(FNMSUB);
-    if (Pattern == MachineCombinerPattern::FMULSUBS_OP1) {
-      Opc = AArch64::FNMSUBSrrr;
-      RC = &AArch64::FPR32RegClass;
-    } else {
-      Opc = AArch64::FNMSUBDrrr;
-      RC = &AArch64::FPR64RegClass;
-    }
+    Opc = AArch64::FNMSUBSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::FMULSUBD_OP1:
+    Opc = AArch64::FNMSUBDrrr;
+    RC = &AArch64::FPR64RegClass;
     MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
     break;
-  }
 
+  case MachineCombinerPattern::FNMULSUBH_OP1:
+    Opc = AArch64::FNMADDHrrr;
+    RC = &AArch64::FPR16RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
   case MachineCombinerPattern::FNMULSUBS_OP1:
-  case MachineCombinerPattern::FNMULSUBD_OP1: {
-    // FNMUL I=A,B,0
-    // FSUB R,I,C
-    // ==> FNMADD R,A,B,C // = -A*B - C
-    // --- Create(FNMADD);
-    if (Pattern == MachineCombinerPattern::FNMULSUBS_OP1) {
-      Opc = AArch64::FNMADDSrrr;
-      RC = &AArch64::FPR32RegClass;
-    } else {
-      Opc = AArch64::FNMADDDrrr;
-      RC = &AArch64::FPR64RegClass;
-    }
+    Opc = AArch64::FNMADDSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::FNMULSUBD_OP1:
+    Opc = AArch64::FNMADDDrrr;
+    RC = &AArch64::FPR64RegClass;
     MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
     break;
-  }
 
+  case MachineCombinerPattern::FMULSUBH_OP2:
+    Opc = AArch64::FMSUBHrrr;
+    RC = &AArch64::FPR16RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
   case MachineCombinerPattern::FMULSUBS_OP2:
-  case MachineCombinerPattern::FMULSUBD_OP2: {
-    // FMUL I=A,B,0
-    // FSUB R,C,I
-    // ==> FMSUB R,A,B,C (computes C - A*B)
-    // --- Create(FMSUB);
-    if (Pattern == MachineCombinerPattern::FMULSUBS_OP2) {
-      Opc = AArch64::FMSUBSrrr;
-      RC = &AArch64::FPR32RegClass;
-    } else {
-      Opc = AArch64::FMSUBDrrr;
-      RC = &AArch64::FPR64RegClass;
-    }
+    Opc = AArch64::FMSUBSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::FMULSUBD_OP2:
+    Opc = AArch64::FMSUBDrrr;
+    RC = &AArch64::FPR64RegClass;
     MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
-  }
 
   case MachineCombinerPattern::FMLSv1i32_indexed_OP2:
     Opc = AArch64::FMLSv1i32_indexed;
@@ -4470,6 +4649,19 @@ void AArch64InstrInfo::genAlternativeCod
                            FMAInstKind::Indexed);
     break;
 
+  case MachineCombinerPattern::FMLSv4f16_OP2:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLSv4f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP2:
+    RC = &AArch64::FPR64RegClass;
+    Opc = AArch64::FMLSv4i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+
   case MachineCombinerPattern::FMLSv2f32_OP2:
   case MachineCombinerPattern::FMLSv2i32_indexed_OP2:
     RC = &AArch64::FPR64RegClass;
@@ -4484,6 +4676,32 @@ void AArch64InstrInfo::genAlternativeCod
     }
     break;
 
+  case MachineCombinerPattern::FMLSv8f16_OP1:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLSv8f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLSv8i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+
+  case MachineCombinerPattern::FMLSv8f16_OP2:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLSv8f16;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Accumulator);
+    break;
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP2:
+    RC = &AArch64::FPR128RegClass;
+    Opc = AArch64::FMLSv8i16_indexed;
+    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC,
+                           FMAInstKind::Indexed);
+    break;
+
   case MachineCombinerPattern::FMLSv2f64_OP2:
   case MachineCombinerPattern::FMLSv2i64_indexed_OP2:
     RC = &AArch64::FPR128RegClass;

Added: llvm/trunk/test/CodeGen/AArch64/fp16-fmla.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/AArch64/fp16-fmla.ll?rev=371321&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/AArch64/fp16-fmla.ll (added)
+++ llvm/trunk/test/CodeGen/AArch64/fp16-fmla.ll Sat Sep  7 13:24:51 2019
@@ -0,0 +1,208 @@
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+v8.2a,+fullfp16 -fp-contract=fast  | FileCheck %s
+
+define half @test_FMULADDH_OP1(half %a, half %b, half %c) {
+; CHECK-LABEL: test_FMULADDH_OP1:
+; CHECK: fmadd    {{h[0-9]+}}, {{h[0-9]+}}, {{h[0-9]+}}
+entry:
+  %mul = fmul fast half %c, %b
+  %add = fadd fast half %mul, %a
+  ret half %add
+}
+
+define half @test_FMULADDH_OP2(half %a, half %b, half %c) {
+; CHECK-LABEL: test_FMULADDH_OP2:
+; CHECK: fmadd    {{h[0-9]+}}, {{h[0-9]+}}, {{h[0-9]+}}
+entry:
+  %mul = fmul fast half %c, %b
+  %add = fadd fast half %a, %mul
+  ret half %add
+}
+
+define half @test_FMULSUBH_OP1(half %a, half %b, half %c) {
+; CHECK-LABEL: test_FMULSUBH_OP1:
+; CHECK: fnmsub    {{h[0-9]+}}, {{h[0-9]+}}, {{h[0-9]+}}
+entry:
+  %mul = fmul fast half %c, %b
+  %sub = fsub fast half %mul, %a
+  ret half %sub
+}
+
+define half @test_FMULSUBH_OP2(half %a, half %b, half %c) {
+; CHECK-LABEL: test_FMULSUBH_OP2:
+; CHECK: fmsub    {{h[0-9]+}}, {{h[0-9]+}}, {{h[0-9]+}}
+entry:
+  %mul = fmul fast half %c, %b
+  %add = fsub fast half %a, %mul
+  ret half %add
+}
+
+define half @test_FNMULSUBH_OP1(half %a, half %b, half %c) {
+; CHECK-LABEL: test_FNMULSUBH_OP1:
+; CHECK: fnmadd    {{h[0-9]+}}, {{h[0-9]+}}, {{h[0-9]+}}
+entry:
+  %mul = fmul fast half %c, %b
+  %neg = fsub fast half -0.0, %mul
+  %add = fsub fast half %neg, %a
+  ret half %add
+}
+
+define <4 x half> @test_FMLAv4f16_OP1(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
+; CHECK-LABEL: test_FMLAv4f16_OP1:
+; CHECK: fmla    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = fmul fast <4 x half> %c, %b
+  %add = fadd fast <4 x half> %mul, %a
+  ret <4 x half> %add
+}
+
+define <4 x half> @test_FMLAv4f16_OP2(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
+; CHECK-LABEL: test_FMLAv4f16_OP2:
+; CHECK: fmla    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = fmul fast <4 x half> %c, %b
+  %add = fadd fast <4 x half> %a, %mul
+  ret <4 x half> %add
+}
+
+define <8 x half> @test_FMLAv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_FMLAv8f16_OP1:
+; CHECK: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = fmul fast <8 x half> %c, %b
+  %add = fadd fast <8 x half> %mul, %a
+  ret <8 x half> %add
+}
+
+define <8 x half> @test_FMLAv8f16_OP2(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_FMLAv8f16_OP2:
+; CHECK: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = fmul fast <8 x half> %c, %b
+  %add = fadd fast <8 x half> %a, %mul
+  ret <8 x half> %add
+}
+
+define <4 x half> @test_FMLAv4i16_indexed_OP1(<4 x half> %a, <4 x i16> %b, <4 x i16> %c) {
+; CHECK-LABEL: test_FMLAv4i16_indexed_OP1:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fadd
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmla    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = mul <4 x i16> %c, %b
+  %m = bitcast <4 x i16> %mul to <4 x half>
+  %add = fadd fast <4 x half> %m, %a
+  ret <4 x half> %add
+}
+
+define <4 x half> @test_FMLAv4i16_indexed_OP2(<4 x half> %a, <4 x i16> %b, <4 x i16> %c) {
+; CHECK-LABEL: test_FMLAv4i16_indexed_OP2:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fadd
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmla    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = mul <4 x i16> %c, %b
+  %m = bitcast <4 x i16> %mul to <4 x half>
+  %add = fadd fast <4 x half> %a, %m
+  ret <4 x half> %add
+}
+
+define <8 x half> @test_FMLAv8i16_indexed_OP1(<8 x half> %a, <8 x i16> %b, <8 x i16> %c) {
+; CHECK-LABEL: test_FMLAv8i16_indexed_OP1:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fadd
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = mul <8 x i16> %c, %b
+  %m = bitcast <8 x i16> %mul to <8 x half>
+  %add = fadd fast <8 x half> %m, %a
+  ret <8 x half> %add
+}
+
+define <8 x half> @test_FMLAv8i16_indexed_OP2(<8 x half> %a, <8 x i16> %b, <8 x i16> %c) {
+; CHECK-LABEL: test_FMLAv8i16_indexed_OP2:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fadd
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = mul <8 x i16> %c, %b
+  %m = bitcast <8 x i16> %mul to <8 x half>
+  %add = fadd fast <8 x half> %a, %m
+  ret <8 x half> %add
+}
+
+define <4 x half> @test_FMLSv4f16_OP2(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
+; CHECK-LABEL: test_FMLSv4f16_OP2:
+; CHECK: fmls    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = fmul fast <4 x half> %c, %b
+  %sub = fsub fast <4 x half> %a, %mul
+  ret <4 x half> %sub
+}
+
+define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_FMLSv8f16_OP1:
+; CHECK: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = fmul fast <8 x half> %c, %b
+  %sub = fsub fast <8 x half> %mul, %a
+  ret <8 x half> %sub
+}
+
+define <8 x half> @test_FMLSv8f16_OP2(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_FMLSv8f16_OP2:
+; CHECK: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = fmul fast <8 x half> %c, %b
+  %sub = fsub fast <8 x half> %a, %mul
+  ret <8 x half> %sub
+}
+
+define <4 x half> @test_FMLSv4i16_indexed_OP2(<4 x half> %a, <4 x i16> %b, <4 x i16> %c) {
+; CHECK-LABEL: test_FMLSv4i16_indexed_OP2:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fsub
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmls    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = mul <4 x i16> %c, %b
+  %m = bitcast <4 x i16> %mul to <4 x half>
+  %sub = fsub fast <4 x half> %a, %m
+  ret <4 x half> %sub
+}
+
+define <8 x half> @test_FMLSv8i16_indexed_OP1(<8 x half> %a, <8 x i16> %b, <8 x i16> %c) {
+; CHECK-LABEL: test_FMLSv8i16_indexed_OP1:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fsub
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = mul <8 x i16> %c, %b
+  %m = bitcast <8 x i16> %mul to <8 x half>
+  %sub = fsub fast <8 x half> %m, %a
+  ret <8 x half> %sub
+}
+
+define <8 x half> @test_FMLSv8i16_indexed_OP2(<8 x half> %a, <8 x i16> %b, <8 x i16> %c) {
+; CHECK-LABEL: test_FMLSv8i16_indexed_OP2:
+; CHECK-FIXME: Currently LLVM produces inefficient code:
+; CHECK: mul
+; CHECK: fsub
+; CHECK-FIXME: It should instead produce the following instruction:
+; CHECK-FIXME: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+entry:
+  %mul = mul <8 x i16> %c, %b
+  %m = bitcast <8 x i16> %mul to <8 x half>
+  %sub = fsub fast <8 x half> %a, %m
+  ret <8 x half> %sub
+}




More information about the llvm-commits mailing list