[llvm] ddc9e88 - [MachineCombiner, AArch64] Add a new pattern A-(B+C) => (A-B)-C to reduce latency

Guozhi Wei via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 28 14:48:54 PDT 2022


Author: Guozhi Wei
Date: 2022-06-28T21:42:51Z
New Revision: ddc9e8861ccf8ac3cf45e9d8cda58bdb2b0be63b

URL: https://github.com/llvm/llvm-project/commit/ddc9e8861ccf8ac3cf45e9d8cda58bdb2b0be63b
DIFF: https://github.com/llvm/llvm-project/commit/ddc9e8861ccf8ac3cf45e9d8cda58bdb2b0be63b.diff

LOG: [MachineCombiner, AArch64] Add a new pattern A-(B+C) => (A-B)-C to reduce latency

Add new patterns A - (B + C) ==> (A - B) - C and A - (B + C) ==> (A - C) - B
to give the machine combiner a chance to evaluate which instruction sequence
has lower latency.

Differential Revision: https://reviews.llvm.org/D124564

Added: 
    llvm/test/CodeGen/AArch64/machine-combiner-subadd.ll
    llvm/test/CodeGen/AArch64/machine-combiner-subadd2.mir

Modified: 
    llvm/include/llvm/CodeGen/MachineCombinerPattern.h
    llvm/lib/CodeGen/MachineCombiner.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 67544779f34c6..68c95679d4667 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -34,6 +34,10 @@ enum class MachineCombinerPattern {
   REASSOC_XY_BCA,
   REASSOC_XY_BAC,
 
+  // These are patterns used to reduce the length of dependence chain.
+  SUBADD_OP1,
+  SUBADD_OP2,
+
   // These are multiply-add patterns matched by the AArch64 machine combiner.
   MULADDW_OP1,
   MULADDW_OP2,

diff  --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp
index 05e4a1e65d391..722a709af2408 100644
--- a/llvm/lib/CodeGen/MachineCombiner.cpp
+++ b/llvm/lib/CodeGen/MachineCombiner.cpp
@@ -277,6 +277,8 @@ static CombinerObjective getCombinerObjective(MachineCombinerPattern P) {
   case MachineCombinerPattern::REASSOC_XA_YB:
   case MachineCombinerPattern::REASSOC_XY_AMM_BMM:
   case MachineCombinerPattern::REASSOC_XMM_AMM_BMM:
+  case MachineCombinerPattern::SUBADD_OP1:
+  case MachineCombinerPattern::SUBADD_OP2:
     return CombinerObjective::MustReduceDepth;
   case MachineCombinerPattern::REASSOC_XY_BCA:
   case MachineCombinerPattern::REASSOC_XY_BAC:

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 05fd190a6e54d..835a7b6cc81d9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -4866,6 +4866,10 @@ static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
       return false;
   }
 
+  if (isCombineInstrSettingFlag(CombineOpc) &&
+      MI->findRegisterDefOperandIdx(AArch64::NZCV, true) == -1)
+    return false;
+
   return true;
 }
 
@@ -5366,6 +5370,42 @@ bool AArch64InstrInfo::isThroughputPattern(
   } // end switch (Pattern)
   return false;
 }
+
+/// Find other MI combine patterns.
+///
+/// Matches a register-register SUB/SUBS \p Root whose second source operand
+/// (the subtrahend) is defined by a register-register ADD/ADDS that
+/// canCombine() accepts:
+///   A - (B + C)  ==>   (A - B) - C  or  (A - C) - B
+/// On a match both SUBADD_OP1 and SUBADD_OP2 are appended to \p Patterns and
+/// true is returned; the machine combiner then decides which variant, if any,
+/// is worth using.
+static bool getMiscPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  const unsigned RootOpc = Root.getOpcode();
+  const bool IsRegRegSub =
+      RootOpc == AArch64::SUBWrr || RootOpc == AArch64::SUBSWrr ||
+      RootOpc == AArch64::SUBXrr || RootOpc == AArch64::SUBSXrr;
+  if (!IsRegRegSub)
+    return false;
+
+  // The rewritten sequence is made of plain SUBs, so a flag-setting root only
+  // qualifies when it has a dead NZCV definition (-1 means none was found).
+  if (isCombineInstrSettingFlag(RootOpc) &&
+      Root.findRegisterDefOperandIdx(AArch64::NZCV, true) == -1)
+    return false;
+
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineOperand &SubtrahendOp = Root.getOperand(2);
+  const unsigned AddOpcodes[] = {AArch64::ADDWrr, AArch64::ADDSWrr,
+                                 AArch64::ADDXrr, AArch64::ADDSXrr};
+  for (unsigned AddOpc : AddOpcodes) {
+    if (!canCombine(MBB, SubtrahendOp, AddOpc))
+      continue;
+    Patterns.push_back(MachineCombinerPattern::SUBADD_OP1);
+    Patterns.push_back(MachineCombinerPattern::SUBADD_OP2);
+    return true;
+  }
+  return false;
+}
+
 /// Return true when there is potentially a faster code sequence for an
 /// instruction chain ending in \p Root. All potential patterns are listed in
 /// the \p Pattern vector. Pattern should be sorted in priority order since the
@@ -5383,6 +5423,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getFMAPatterns(Root, Patterns))
     return true;
 
+  // Other patterns
+  if (getMiscPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -5633,6 +5677,53 @@ static MachineInstr *genMaddR(MachineFunction &MF, MachineRegisterInfo &MRI,
   return MUL;
 }
 
+/// Do the following transformation
+/// A - (B + C)  ==>   (A - B) - C
+/// A - (B + C)  ==>   (A - C) - B
+///
+/// \p Root is a register-register SUB/SUBS, and AddMI (the unique definition
+/// of Root's second source operand) is a register-register ADD/ADDS.
+/// \p IdxOpd1 (1 or 2) selects the ADD operand that is subtracted first ("B");
+/// the other ADD operand ("C") is subtracted second. The two new SUBs are
+/// appended to \p InsInstrs, the intermediate vreg is recorded in
+/// \p InstrIdxForVirtReg, and AddMI is appended to \p DelInstrs (Root itself
+/// is not added here).
+static void
+genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
+                 const TargetInstrInfo *TII, MachineInstr &Root,
+                 SmallVectorImpl<MachineInstr *> &InsInstrs,
+                 SmallVectorImpl<MachineInstr *> &DelInstrs,
+                 unsigned IdxOpd1,
+                 DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+  assert(IdxOpd1 == 1 || IdxOpd1 == 2);
+  unsigned IdxOtherOpd = IdxOpd1 == 1 ? 2 : 1;
+  Register AddResultReg = Root.getOperand(2).getReg();
+  MachineInstr *AddMI = MRI.getUniqueVRegDef(AddResultReg);
+
+  Register ResultReg = Root.getOperand(0).getReg();
+  Register RegA = Root.getOperand(1).getReg();
+  bool RegAIsKill = Root.getOperand(1).isKill();
+  Register RegB = AddMI->getOperand(IdxOpd1).getReg();
+  bool RegBIsKill = AddMI->getOperand(IdxOpd1).isKill();
+  Register RegC = AddMI->getOperand(IdxOtherOpd).getReg();
+  bool RegCIsKill = AddMI->getOperand(IdxOtherOpd).isKill();
+
+  // RegA and RegB are read by the first new SUB and RegC by the second one.
+  // If RegC is the same register as RegA or RegB (e.g. "A - (B + A)"), a kill
+  // flag on the first SUB would make the second SUB read a killed register,
+  // so move such a kill onto the last use in the second SUB.
+  if (RegA == RegC) {
+    RegCIsKill |= RegAIsKill;
+    RegAIsKill = false;
+  }
+  if (RegB == RegC) {
+    RegCIsKill |= RegBIsKill;
+    RegBIsKill = false;
+  }
+
+  // Nothing requires Root's first source operand to be a virtual register
+  // (it may be a physical one, such as a zero register), and
+  // MRI.getRegClass() is only defined for virtual registers. In that case,
+  // fall back to the class of the ADD result.
+  const TargetRegisterClass *NewRC = RegA.isVirtual()
+                                         ? MRI.getRegClass(RegA)
+                                         : MRI.getRegClass(AddResultReg);
+  Register NewVR = MRI.createVirtualRegister(NewRC);
+
+  // getMiscPatterns only accepts a flag-setting SUBS root whose NZCV
+  // definition is dead, so the non-flag-setting form is an equivalent
+  // replacement.
+  unsigned Opcode = Root.getOpcode();
+  if (Opcode == AArch64::SUBSWrr)
+    Opcode = AArch64::SUBWrr;
+  else if (Opcode == AArch64::SUBSXrr)
+    Opcode = AArch64::SUBXrr;
+  else
+    assert((Opcode == AArch64::SUBWrr || Opcode == AArch64::SUBXrr) &&
+           "Unexpected instruction opcode.");
+
+  // NewVR = A - B
+  MachineInstrBuilder MIB1 =
+      BuildMI(MF, Root.getDebugLoc(), TII->get(Opcode), NewVR)
+          .addReg(RegA, getKillRegState(RegAIsKill))
+          .addReg(RegB, getKillRegState(RegBIsKill));
+  // Result = NewVR - C; this is the only use of NewVR.
+  MachineInstrBuilder MIB2 =
+      BuildMI(MF, Root.getDebugLoc(), TII->get(Opcode), ResultReg)
+          .addReg(NewVR, getKillRegState(true))
+          .addReg(RegC, getKillRegState(RegCIsKill));
+
+  // NewVR is defined by the first instruction pushed below. Index 0 assumes
+  // InsInstrs is empty on entry (TODO confirm at call sites).
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+  InsInstrs.push_back(MIB1);
+  InsInstrs.push_back(MIB2);
+  DelInstrs.push_back(AddMI);
+}
+
 /// When getMachineCombinerPatterns() finds potential patterns,
 /// this function generates the instructions that could replace the
 /// original code sequence
@@ -5655,6 +5746,18 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
                                                 DelInstrs, InstrIdxForVirtReg);
     return;
+  case MachineCombinerPattern::SUBADD_OP1:
+    // A - (B + C)
+    // ==> (A - B) - C
+    genSubAdd2SubSub(MF, MRI, TII, Root, InsInstrs, DelInstrs, 1,
+                     InstrIdxForVirtReg);
+    break;
+  case MachineCombinerPattern::SUBADD_OP2:
+    // A - (B + C)
+    // ==> (A - C) - B
+    genSubAdd2SubSub(MF, MRI, TII, Root, InsInstrs, DelInstrs, 2,
+                     InstrIdxForVirtReg);
+    break;
   case MachineCombinerPattern::MULADDW_OP1:
   case MachineCombinerPattern::MULADDX_OP1:
     // MUL I=A,B,0

diff  --git a/llvm/test/CodeGen/AArch64/machine-combiner-subadd.ll b/llvm/test/CodeGen/AArch64/machine-combiner-subadd.ll
new file mode 100644
index 0000000000000..77c3d4e4df2df
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/machine-combiner-subadd.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu %s -o - | FileCheck %s
+
+; The test cases in this file check that the following transformation is
+; performed only when the right-hand form reduces latency:
+;     A - (B + C)  ==>   (A - B) - C
+
+; 32 bit version.
+; Here %sub = %c1 - (%xor + %a1). %xor is derived from %a1 (through %shl), so
+; it is the later of the two addends; the expected code computes %c1 - %a1
+; first and subtracts %xor afterwards.
+define i32 @test1(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add w9, w0, #100
+; CHECK-NEXT:    orr w8, w2, #0x80
+; CHECK-NEXT:    sub w8, w8, w9
+; CHECK-NEXT:    eor w9, w1, w9, lsl #8
+; CHECK-NEXT:    sub w8, w8, w9
+; CHECK-NEXT:    eor w0, w8, w9, asr #13
+; CHECK-NEXT:    ret
+entry:
+  %c1  = or  i32 %c, 128
+  %a1  = add i32 %a, 100
+  %shl = shl i32 %a1, 8
+  %xor = xor i32 %shl, %b
+  %add = add i32 %xor, %a1
+  %sub = sub i32 %c1, %add
+  %shr = ashr i32 %xor, 13
+  %xor2 = xor i32 %sub, %shr
+  ret i32 %xor2
+}
+
+; 64 bit version.
+; Same shape as test1 on i64: %sub = %c1 - (%xor + %a1), and the expected code
+; subtracts %a1 before the later-available %xor.
+define i64 @test2(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: test2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add x9, x0, #100
+; CHECK-NEXT:    orr x8, x2, #0x80
+; CHECK-NEXT:    sub x8, x8, x9
+; CHECK-NEXT:    eor x9, x1, x9, lsl #8
+; CHECK-NEXT:    sub x8, x8, x9
+; CHECK-NEXT:    eor x0, x8, x9, asr #13
+; CHECK-NEXT:    ret
+entry:
+  %c1  = or  i64 %c, 128
+  %a1  = add i64 %a, 100
+  %shl = shl i64 %a1, 8
+  %xor = xor i64 %shl, %b
+  %add = add i64 %xor, %a1
+  %sub = sub i64 %c1, %add
+  %shr = ashr i64 %xor, 13
+  %xor2 = xor i64 %sub, %shr
+  ret i64 %xor2
+}
+
+; Negative test. The right form can't reduce latency.
+; Here %sub = %xor - (%c1 + %a1). The late value %xor is the minuend, so both
+; rewritten forms would still start by waiting for %xor; the expected code
+; keeps the ADD followed by a single SUB.
+define i32 @test3(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test3:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add w9, w0, #100
+; CHECK-NEXT:    orr w8, w2, #0x80
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    eor w9, w1, w9, lsl #8
+; CHECK-NEXT:    sub w8, w9, w8
+; CHECK-NEXT:    eor w0, w8, w9, asr #13
+; CHECK-NEXT:    ret
+entry:
+  %c1  = or  i32 %c, 128
+  %a1  = add i32 %a, 100
+  %shl = shl i32 %a1, 8
+  %xor = xor i32 %shl, %b
+  %add = add i32 %c1, %a1
+  %sub = sub i32 %xor, %add
+  %shr = ashr i32 %xor, 13
+  %xor2 = xor i32 %sub, %shr
+  ret i32 %xor2
+}
+

diff  --git a/llvm/test/CodeGen/AArch64/machine-combiner-subadd2.mir b/llvm/test/CodeGen/AArch64/machine-combiner-subadd2.mir
new file mode 100644
index 0000000000000..395fa3f023e19
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/machine-combiner-subadd2.mir
@@ -0,0 +1,239 @@
+# RUN: llc -mtriple=aarch64-linux-gnu -run-pass machine-combiner -o - %s | FileCheck %s
+
+# The test cases in this file check that the following transformation is
+# performed only when the right-hand form reduces latency:
+#     A - (B + C)  ==>   (A - B) - C
+
+---
+# 32 bit.
+# %7 = %3 - (%5 + %4), where %5 is computed from %4 (EORWrs) and is therefore
+# the later addend. The expected output subtracts %4 first and %5 second, so
+# the first SUB no longer depends on %5.
+
+# CHECK-LABEL: name: test1
+# CHECK:       %10:gpr32common = SUBWrr killed %3, %4
+# CHECK-NEXT:  %7:gpr32 = SUBWrr killed %10, %5
+
+name:            test1
+registers:
+  - { id: 0, class: gpr32common }
+  - { id: 1, class: gpr32 }
+  - { id: 2, class: gpr32 }
+  - { id: 3, class: gpr32common }
+  - { id: 4, class: gpr32common }
+  - { id: 5, class: gpr32 }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+  - { id: 8, class: gpr32 }
+body:              |
+  bb.0:
+    %2:gpr32 = COPY $w2
+    %1:gpr32 = COPY $w1
+    %0:gpr32common = COPY $w0
+    %3:gpr32common = ORRWri %2:gpr32, 1600
+    %4:gpr32common = ADDWri %0:gpr32common, 100, 0
+    %5:gpr32 = EORWrs %1:gpr32, %4:gpr32common, 8
+    %6:gpr32 = ADDWrr %5:gpr32, %4:gpr32common
+    %7:gpr32 = SUBWrr killed %3:gpr32common, killed %6:gpr32
+    %8:gpr32 = EORWrs killed %7:gpr32, %5:gpr32, 141
+    $w0 = COPY %8:gpr32
+    RET_ReallyLR implicit $w0
+
+...
+---
+# 64 bit.
+# Same shape as test1 with X registers: %7 = %3 - (%5 + %4) becomes
+# (%3 - %4) - %5.
+
+# CHECK-LABEL: name: test2
+# CHECK:       %10:gpr64common = SUBXrr killed %3, %4
+# CHECK-NEXT:  %7:gpr64 = SUBXrr killed %10, %5
+
+name:            test2
+registers:
+  - { id: 0, class: gpr64common }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: gpr64 }
+  - { id: 3, class: gpr64common }
+  - { id: 4, class: gpr64common }
+  - { id: 5, class: gpr64 }
+  - { id: 6, class: gpr64 }
+  - { id: 7, class: gpr64 }
+  - { id: 8, class: gpr64 }
+body:              |
+  bb.0:
+    %2:gpr64 = COPY $x2
+    %1:gpr64 = COPY $x1
+    %0:gpr64common = COPY $x0
+    %3:gpr64common = ORRXri %2:gpr64, 1600
+    %4:gpr64common = ADDXri %0:gpr64common, 100, 0
+    %5:gpr64 = EORXrs %1:gpr64, %4:gpr64common, 8
+    %6:gpr64 = ADDXrr %5:gpr64, %4:gpr64common
+    %7:gpr64 = SUBXrr killed %3:gpr64common, killed %6:gpr64
+    %8:gpr64 = EORXrs killed %7:gpr64, %5:gpr64, 141
+    $x0 = COPY %8:gpr64
+    RET_ReallyLR implicit $x0
+
+...
+---
+# Negative test. The right form can't reduce latency.
+# Here %7 = %5 - (%3 + %4): the late value %5 is the minuend A, and both
+# (A - B) - C and (A - C) - B begin by reading A, so the ADD/SUB pair is kept.
+
+# CHECK-LABEL: name: test3
+# CHECK:       %6:gpr32 = ADDWrr killed %3, %4
+# CHECK-NEXT:  %7:gpr32 = SUBWrr %5, killed %6
+
+name:           test3
+registers:
+  - { id: 0, class: gpr32common }
+  - { id: 1, class: gpr32 }
+  - { id: 2, class: gpr32 }
+  - { id: 3, class: gpr32common }
+  - { id: 4, class: gpr32common }
+  - { id: 5, class: gpr32 }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+  - { id: 8, class: gpr32 }
+body:              |
+  bb.0:
+    %2:gpr32 = COPY $w2
+    %1:gpr32 = COPY $w1
+    %0:gpr32common = COPY $w0
+    %3:gpr32common = ORRWri %2:gpr32, 1600
+    %4:gpr32common = ADDWri %0:gpr32common, 100, 0
+    %5:gpr32 = EORWrs %1:gpr32, %4:gpr32common, 8
+    %6:gpr32 = ADDWrr killed %3:gpr32common, %4:gpr32common
+    %7:gpr32 = SUBWrr %5:gpr32, killed %6:gpr32
+    %8:gpr32 = EORWrs killed %7:gpr32, %5:gpr32, 141
+    $w0 = COPY %8:gpr32
+    RET_ReallyLR implicit $w0
+
+...
+---
+# Dead definitions of the flag register should not block the transformation.
+# Both ADDSXrr and SUBSXrr define $nzcv as dead, so the pair is rewritten into
+# two plain SUBXrr instructions.
+
+# CHECK-LABEL: name: test4
+# CHECK:       %10:gpr64common = SUBXrr killed %3, %4
+# CHECK-NEXT:  %7:gpr64 = SUBXrr killed %10, %5
+
+name:            test4
+registers:
+  - { id: 0, class: gpr64common }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: gpr64 }
+  - { id: 3, class: gpr64common }
+  - { id: 4, class: gpr64common }
+  - { id: 5, class: gpr64 }
+  - { id: 6, class: gpr64 }
+  - { id: 7, class: gpr64 }
+  - { id: 8, class: gpr64 }
+body:              |
+  bb.0:
+    %2:gpr64 = COPY $x2
+    %1:gpr64 = COPY $x1
+    %0:gpr64common = COPY $x0
+    %3:gpr64common = ORRXri %2:gpr64, 1600
+    %4:gpr64common = ADDXri %0:gpr64common, 100, 0
+    %5:gpr64 = EORXrs %1:gpr64, %4:gpr64common, 8
+    %6:gpr64 = ADDSXrr %5:gpr64, %4:gpr64common, implicit-def dead $nzcv
+    %7:gpr64 = SUBSXrr killed %3:gpr64common, killed %6:gpr64, implicit-def dead $nzcv
+    %8:gpr64 = EORXrs killed %7:gpr64, %5:gpr64, 141
+    $x0 = COPY %8:gpr64
+    RET_ReallyLR implicit $x0
+
+...
+---
+# A live (non-dead) flag-register definition on the SUB blocks the
+# transformation, because the replacement sequence uses plain SUBs that do
+# not set flags.
+
+# CHECK-LABEL: name: test5
+# CHECK:       %6:gpr32 = ADDWrr %5, %4
+# CHECK-NEXT:  %7:gpr32 = SUBSWrr killed %3, killed %6, implicit-def $nzcv
+
+name:            test5
+registers:
+  - { id: 0, class: gpr32common }
+  - { id: 1, class: gpr32 }
+  - { id: 2, class: gpr32 }
+  - { id: 3, class: gpr32common }
+  - { id: 4, class: gpr32common }
+  - { id: 5, class: gpr32 }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+  - { id: 8, class: gpr32 }
+body:              |
+  bb.0:
+    %2:gpr32 = COPY $w2
+    %1:gpr32 = COPY $w1
+    %0:gpr32common = COPY $w0
+    %3:gpr32common = ORRWri %2:gpr32, 1600
+    %4:gpr32common = ADDWri %0:gpr32common, 100, 0
+    %5:gpr32 = EORWrs %1:gpr32, %4:gpr32common, 8
+    %6:gpr32 = ADDWrr %5:gpr32, %4:gpr32common
+    %7:gpr32 = SUBSWrr killed %3:gpr32common, killed %6:gpr32, implicit-def $nzcv
+    %8:gpr32 = EORWrs killed %7:gpr32, %5:gpr32, 141
+    $w0 = COPY %8:gpr32
+    RET_ReallyLR implicit $w0
+
+...
+---
+# A live (non-dead) flag-register definition on the ADD blocks the
+# transformation, since the ADDSXrr would be deleted and its flags lost.
+
+# CHECK-LABEL: name: test6
+# CHECK:       %6:gpr64 = ADDSXrr %5, %4, implicit-def $nzcv
+# CHECK-NEXT:  %7:gpr64 = SUBXrr killed %3, killed %6
+
+name:            test6
+registers:
+  - { id: 0, class: gpr64common }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: gpr64 }
+  - { id: 3, class: gpr64common }
+  - { id: 4, class: gpr64common }
+  - { id: 5, class: gpr64 }
+  - { id: 6, class: gpr64 }
+  - { id: 7, class: gpr64 }
+  - { id: 8, class: gpr64 }
+body:              |
+  bb.0:
+    %2:gpr64 = COPY $x2
+    %1:gpr64 = COPY $x1
+    %0:gpr64common = COPY $x0
+    %3:gpr64common = ORRXri %2:gpr64, 1600
+    %4:gpr64common = ADDXri %0:gpr64common, 100, 0
+    %5:gpr64 = EORXrs %1:gpr64, %4:gpr64common, 8
+    %6:gpr64 = ADDSXrr %5:gpr64, %4:gpr64common, implicit-def $nzcv
+    %7:gpr64 = SUBXrr killed %3:gpr64common, killed %6:gpr64
+    %8:gpr64 = EORXrs killed %7:gpr64, %5:gpr64, 141
+    $x0 = COPY %8:gpr64
+    RET_ReallyLR implicit $x0
+
+...
+---
+# The ADD result %6 has a second use (in %9), so the ADD must stay anyway.
+# Splitting the SUB would not remove it, and the transformation is not applied.
+
+# CHECK-LABEL: name: test7
+# CHECK:       %6:gpr32 = ADDWrr %5, %4
+# CHECK-NEXT:  %7:gpr32 = SUBWrr killed %3, %6
+
+name:            test7
+registers:
+  - { id: 0, class: gpr32common }
+  - { id: 1, class: gpr32 }
+  - { id: 2, class: gpr32 }
+  - { id: 3, class: gpr32common }
+  - { id: 4, class: gpr32common }
+  - { id: 5, class: gpr32 }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+  - { id: 8, class: gpr32 }
+  - { id: 9, class: gpr32 }
+body:              |
+  bb.0:
+    %2:gpr32 = COPY $w2
+    %1:gpr32 = COPY $w1
+    %0:gpr32common = COPY $w0
+    %3:gpr32common = ORRWri %2:gpr32, 1600
+    %4:gpr32common = ADDWri %0:gpr32common, 100, 0
+    %5:gpr32 = EORWrs %1:gpr32, %4:gpr32common, 8
+    %6:gpr32 = ADDWrr %5:gpr32, %4:gpr32common
+    %7:gpr32 = SUBWrr killed %3:gpr32common, %6:gpr32
+    %8:gpr32 = EORWrs killed %7:gpr32, %5:gpr32, 141
+    %9:gpr32 = ADDWrr %8:gpr32, %6:gpr32
+    $w0 = COPY %9:gpr32
+    RET_ReallyLR implicit $w0
+
+...


        


More information about the llvm-commits mailing list