[llvm] e503fee - [AArch64] Fix MUL/SUB fusing

Sanne Wouda via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 10:10:45 PST 2019


Author: Sanne Wouda
Date: 2019-12-05T18:10:06Z
New Revision: e503fee904d8c17c089d27ab928bc72eeeece649

URL: https://github.com/llvm/llvm-project/commit/e503fee904d8c17c089d27ab928bc72eeeece649
DIFF: https://github.com/llvm/llvm-project/commit/e503fee904d8c17c089d27ab928bc72eeeece649.diff

LOG: [AArch64] Fix MUL/SUB fusing

Summary:
When MUL is the first operand of SUB (i.e. computing a*b - c), we can't use MLS,
which computes c - a*b: the accumulator c would have to be negated first.  Emit a
NEG of the accumulator followed by an MLA instead, similar to what we do for
FMUL / FSUB fusing.

Reviewers: dmgreen, SjoerdMeijer, fhahn, Gerolf, mstorsjo, asbirlea

Reviewed By: asbirlea

Subscribers: kristof.beyls, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71067

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/test/CodeGen/AArch64/neon-mla-mls.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 714007f8aba8..15d908d0797e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -4198,6 +4198,40 @@ static MachineInstr *genFusedMultiplyAcc(
                           FMAInstKind::Accumulator);
 }
 
+/// genNeg - Append a MnegOpc instruction that negates the second source
+/// operand of Root (operand 2) to InsInstrs; return the register it defines.
+static Register genNeg(MachineFunction &MF, MachineRegisterInfo &MRI,
+                       const TargetInstrInfo *TII, MachineInstr &Root,
+                       SmallVectorImpl<MachineInstr *> &InsInstrs,
+                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
+                       unsigned MnegOpc, const TargetRegisterClass *RC) {
+  const MachineOperand &Addend = Root.getOperand(2);
+  Register NegReg = MRI.createVirtualRegister(RC);
+  InsInstrs.push_back(
+      BuildMI(MF, Root.getDebugLoc(), TII->get(MnegOpc), NegReg).add(Addend));
+
+  // NegReg is defined by new instruction 0; nothing may be recorded before it.
+  assert(InstrIdxForVirtReg.empty());
+  InstrIdxForVirtReg.insert(std::make_pair(NegReg, 0));
+
+  return NegReg;
+}
+
+/// genFusedMultiplyAccNeg - Emit a NEG of Root's second operand, then a fused
+/// multiply-accumulate (MaddOpc) that uses the negated value as accumulator.
+static MachineInstr *genFusedMultiplyAccNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  // Only the MUL-as-first-operand form (mul - c) needs a negated addend.
+  assert(IdxMulOpd == 1);
+  Register NegAddend =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Accumulator, &NegAddend);
+}
+
 /// genFusedMultiplyIdx - Helper to generate fused multiply accumulate
 /// instructions.
 ///
@@ -4210,6 +4244,22 @@ static MachineInstr *genFusedMultiplyIdx(
                           FMAInstKind::Indexed);
 }
 
+/// genFusedMultiplyIdxNeg - Helper to generate indexed fused multiply
+/// accumulate instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyIdxNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+  // Negate Root's second operand; it becomes the accumulator of MaddOpc.
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Indexed, &NewVR);
+}
+
 /// genMaddR - Generate madd instruction and combine mul and add using
 /// an extra virtual register
 /// Example - an ADD intermediate needs to be stored in a register:
@@ -4512,9 +4562,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
 
   case MachineCombinerPattern::MULSUBv8i8_OP1:
-    Opc = AArch64::MLSv8i8;
+    Opc = AArch64::MLAv8i8;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i8_OP2:
     Opc = AArch64::MLSv8i8;
@@ -4522,9 +4574,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP1:
-    Opc = AArch64::MLSv16i8;
+    Opc = AArch64::MLAv16i8;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv16i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP2:
     Opc = AArch64::MLSv16i8;
@@ -4532,9 +4586,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP1:
-    Opc = AArch64::MLSv4i16;
+    Opc = AArch64::MLAv4i16;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP2:
     Opc = AArch64::MLSv4i16;
@@ -4542,9 +4598,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP1:
-    Opc = AArch64::MLSv8i16;
+    Opc = AArch64::MLAv8i16;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP2:
     Opc = AArch64::MLSv8i16;
@@ -4552,9 +4610,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP1:
-    Opc = AArch64::MLSv2i32;
+    Opc = AArch64::MLAv2i32;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP2:
     Opc = AArch64::MLSv2i32;
@@ -4562,9 +4622,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP1:
-    Opc = AArch64::MLSv4i32;
+    Opc = AArch64::MLAv4i32;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP2:
     Opc = AArch64::MLSv4i32;
@@ -4614,9 +4676,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
 
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
-    Opc = AArch64::MLSv4i16_indexed;
+    Opc = AArch64::MLAv4i16_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
     Opc = AArch64::MLSv4i16_indexed;
@@ -4624,9 +4688,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
-    Opc = AArch64::MLSv8i16_indexed;
+    Opc = AArch64::MLAv8i16_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
     Opc = AArch64::MLSv8i16_indexed;
@@ -4634,9 +4700,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
-    Opc = AArch64::MLSv2i32_indexed;
+    Opc = AArch64::MLAv2i32_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
     Opc = AArch64::MLSv2i32_indexed;
@@ -4644,9 +4712,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
-    Opc = AArch64::MLSv4i32_indexed;
+    Opc = AArch64::MLAv4i32_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
     Opc = AArch64::MLSv4i32_indexed;

diff  --git a/llvm/test/CodeGen/AArch64/neon-mla-mls.ll b/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
index a4b9ef8eff57..08fb8a5631a3 100644
--- a/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
+++ b/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
@@ -135,3 +135,75 @@ define <4 x i32> @mls4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
 }
 
 
+define <8 x i8> @mls2v8xi8(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
+; CHECK-LABEL: mls2v8xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8b, v2.8b
+; CHECK-NEXT:    mla v2.8b, v0.8b, v1.8b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <8 x i8> %a, %b
+  %diff = sub <8 x i8> %prod, %c
+  ret <8 x i8> %diff
+}
+
+define <16 x i8> @mls2v16xi8(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c) {
+; CHECK-LABEL: mls2v16xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.16b, v2.16b
+; CHECK-NEXT:    mla v2.16b, v0.16b, v1.16b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <16 x i8> %a, %b
+  %diff = sub <16 x i8> %prod, %c
+  ret <16 x i8> %diff
+}
+
+define <4 x i16> @mls2v4xi16(<4 x i16> %a, <4 x i16> %b, <4 x i16> %c) {
+; CHECK-LABEL: mls2v4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4h, v2.4h
+; CHECK-NEXT:    mla v2.4h, v0.4h, v1.4h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <4 x i16> %a, %b
+  %diff = sub <4 x i16> %prod, %c
+  ret <4 x i16> %diff
+}
+
+define <8 x i16> @mls2v8xi16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c) {
+; CHECK-LABEL: mls2v8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8h, v2.8h
+; CHECK-NEXT:    mla v2.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <8 x i16> %a, %b
+  %diff = sub <8 x i16> %prod, %c
+  ret <8 x i16> %diff
+}
+
+define <2 x i32> @mls2v2xi32(<2 x i32> %a, <2 x i32> %b, <2 x i32> %c) {
+; CHECK-LABEL: mls2v2xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.2s, v2.2s
+; CHECK-NEXT:    mla v2.2s, v0.2s, v1.2s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <2 x i32> %a, %b
+  %diff = sub <2 x i32> %prod, %c
+  ret <2 x i32> %diff
+}
+
+define <4 x i32> @mls2v4xi32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
+; CHECK-LABEL: mls2v4xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4s, v2.4s
+; CHECK-NEXT:    mla v2.4s, v0.4s, v1.4s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %prod = mul <4 x i32> %a, %b
+  %diff = sub <4 x i32> %prod, %c
+  ret <4 x i32> %diff
+}
+


        


More information about the llvm-commits mailing list