[llvm] [MacroFusion] Support commutable instructions (PR #82751)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 5 03:54:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Wang Pengcheng (wangpc-pp)

<details>
<summary>Changes</summary>

If the second instruction is commutable, we should be able to check
its commutable operands.

A simple RISCV fusion is included in this PR to show that the
functionality is correct; I may remove it when landing.

There are some other issues I should fix. For example, we should be
able to check that the destination register is the same as the
commutable operands. But I am posting this PR here first to gather
feedback.

Fixes #<!-- -->82738

---
Full diff: https://github.com/llvm/llvm-project/pull/82751.diff


5 Files Affected:

- (modified) llvm/include/llvm/Target/TargetSchedule.td (+13-6) 
- (modified) llvm/lib/Target/RISCV/RISCVMacroFusion.td (+21) 
- (modified) llvm/test/CodeGen/RISCV/macro-fusions.mir (+29-1) 
- (modified) llvm/test/TableGen/MacroFusion.td (+42-9) 
- (modified) llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp (+40-9) 


``````````diff
diff --git a/llvm/include/llvm/Target/TargetSchedule.td b/llvm/include/llvm/Target/TargetSchedule.td
index 48c9387977c075..a872cc30a9b50c 100644
--- a/llvm/include/llvm/Target/TargetSchedule.td
+++ b/llvm/include/llvm/Target/TargetSchedule.td
@@ -617,16 +617,27 @@ class SecondFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
   : FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
 // The pred will be applied on both firstMI and secondMI.
 class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
-  : FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
+  : FusionPredicateWithMCInstPredicate<both_fusion_target, pred>;
 
 // Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
 // `firstOpIdx` should be the same as the operand of `SecondMI` at position
 // `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
 class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
   int FirstOpIdx = firstOpIdx;
   int SecondOpIdx = secondOpIdx;
 }
 
+// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
+// operand at position `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
+class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
+  int FirstOpIdx = firstOpIdx;
+  int SecondOpIdx = secondOpIdx;
+}
+
 // A predicate for wildcard. The generated code will be like:
 // ```
 // if (!FirstMI)
@@ -688,11 +699,7 @@ class SimpleFusion<string name, string fieldName, string desc,
                 SecondFusionPredicateWithMCInstPredicate<secondPred>,
                 WildcardTrue,
                 FirstFusionPredicateWithMCInstPredicate<firstPred>,
-                SecondFusionPredicateWithMCInstPredicate<
-                  CheckAny<[
-                    CheckIsVRegOperand<0>,
-                    CheckSameRegOperand<0, 1>
-                  ]>>,
+                SameReg<0, 1>,
                 OneUse,
                 TieReg<0, 1>,
               ],
diff --git a/llvm/lib/Target/RISCV/RISCVMacroFusion.td b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
index 875a93d09a2c64..14e8962f8ce110 100644
--- a/llvm/lib/Target/RISCV/RISCVMacroFusion.td
+++ b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
@@ -91,3 +91,24 @@ def TuneLDADDFusion
                    CheckIsImmOperand<2>,
                    CheckImmOperand<2, 0>
                  ]>>;
+
+// These should be covered by Zba extension.
+// * shift left by one and add:
+//   slli r1, r0, 1
+//   add r1, r1, r2
+// * shift left by two and add:
+//   slli r1, r0, 2
+//   add r1, r1, r2
+// * shift left by three and add:
+//   slli r1, r0, 3
+//   add r1, r1, r2
+def ShiftNAddFusion
+  : SimpleFusion<"shift-n-add-fusion", "HasShiftNAddFusion",
+                 "Enable SLLI+ADD to be fused to shift left by 1/2/3 and add",
+                 CheckAll<[
+                   CheckOpcode<[SLLI]>,
+                   CheckAny<[CheckImmOperand<2, 1>,
+                             CheckImmOperand<2, 2>,
+                             CheckImmOperand<2, 3>]>
+                 ]>,
+                 CheckOpcode<[ADD]>>;
diff --git a/llvm/test/CodeGen/RISCV/macro-fusions.mir b/llvm/test/CodeGen/RISCV/macro-fusions.mir
index 13464141ce27e6..11bb456e0ca050 100644
--- a/llvm/test/CodeGen/RISCV/macro-fusions.mir
+++ b/llvm/test/CodeGen/RISCV/macro-fusions.mir
@@ -1,7 +1,7 @@
 # REQUIRES: asserts
 # RUN: llc -mtriple=riscv64-linux-gnu -x=mir < %s \
 # RUN:   -debug-only=machine-scheduler -start-before=machine-scheduler 2>&1 \
-# RUN:   -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion \
+# RUN:   -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion,+shift-n-add-fusion \
 # RUN:   | FileCheck %s
 
 # CHECK: lui_addi:%bb.0
@@ -174,3 +174,31 @@ body:             |
     $x11 = COPY %5
     PseudoRET
 ...
+
+# CHECK: shift_n_add:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+    $x10 = SLLI $x10, 1
+    $x12 = XORI $x11, 3
+    $x10 = ADD $x10, $x11
+    PseudoRET
+...
+
+# CHECK: shift_n_add_commutable:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add_commutable
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+    $x10 = SLLI $x10, 1
+    $x12 = XORI $x11, 3
+    $x10 = ADD $x11, $x10
+    PseudoRET
+...
diff --git a/llvm/test/TableGen/MacroFusion.td b/llvm/test/TableGen/MacroFusion.td
index 4aa6c8d9acb273..c1cdd0fee78ccd 100644
--- a/llvm/test/TableGen/MacroFusion.td
+++ b/llvm/test/TableGen/MacroFusion.td
@@ -34,6 +34,11 @@ let Namespace = "Test" in {
 def Inst0 : TestInst<0>;
 def Inst1 : TestInst<1>;
 
+def BothFusionPredicate: BothFusionPredicateWithMCInstPredicate<CheckRegOperand<0, X0>>;
+def TestBothFusionPredicate: Fusion<"test-both-fusion-predicate", "HasBothFusionPredicate",
+                                    "Test BothFusionPredicate",
+                                    [BothFusionPredicate]>;
+
 def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
                              CheckOpcode<[Inst0]>,
                              CheckAll<[
@@ -45,6 +50,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:  #undef GET_Test_MACRO_FUSION_PRED_DECL
 // CHECK-PREDICATOR-EMPTY:
 // CHECK-PREDICATOR-NEXT:  namespace llvm {
+// CHECK-PREDICATOR-NEXT:  bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
 // CHECK-PREDICATOR-NEXT:  bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
 // CHECK-PREDICATOR-NEXT:  } // end namespace llvm
 // CHECK-PREDICATOR-EMPTY:
@@ -54,6 +60,24 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:  #undef GET_Test_MACRO_FUSION_PRED_IMPL
 // CHECK-PREDICATOR-EMPTY:
 // CHECK-PREDICATOR-NEXT:  namespace llvm {
+// CHECK-PREDICATOR-NEXT:  bool isTestBothFusionPredicate(
+// CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
+// CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *FirstMI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr &SecondMI) {
+// CHECK-PREDICATOR-NEXT:    auto &MRI = SecondMI.getMF()->getRegInfo();
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = FirstMI;
+// CHECK-PREDICATOR-NEXT:      if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = &SecondMI;
+// CHECK-PREDICATOR-NEXT:      if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    return true;
+// CHECK-PREDICATOR-NEXT:  }
 // CHECK-PREDICATOR-NEXT:  bool isTestFusion(
 // CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
 // CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
@@ -75,13 +99,15 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:      if (( MI->getOpcode() != Test::Inst0 ))
 // CHECK-PREDICATOR-NEXT:        return false;
 // CHECK-PREDICATOR-NEXT:    }
-// CHECK-PREDICATOR-NEXT:    {
-// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = &SecondMI;
-// CHECK-PREDICATOR-NEXT:      if (!(
-// CHECK-PREDICATOR-NEXT:          MI->getOperand(0).getReg().isVirtual()
-// CHECK-PREDICATOR-NEXT:          || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
-// CHECK-PREDICATOR-NEXT:        ))
-// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    if (!SecondMI.getOperand(0).getReg().isVirtual()) {                                                                                                                                                                                      
+// CHECK-PREDICATOR-NEXT:      if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {                                                                                                                                                              
+// CHECK-PREDICATOR-NEXT:        if (!SecondMI.getDesc().isCommutable())                                                                                                                                                                                              
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:        unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:        if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:          if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:            return false;
+// CHECK-PREDICATOR-NEXT:      }
 // CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    {
 // CHECK-PREDICATOR-NEXT:      Register FirstDest = FirstMI->getOperand(0).getReg();
@@ -90,8 +116,14 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    if (!(FirstMI->getOperand(0).isReg() &&
 // CHECK-PREDICATOR-NEXT:          SecondMI.getOperand(1).isReg() &&
-// CHECK-PREDICATOR-NEXT:          FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg()))
-// CHECK-PREDICATOR-NEXT:      return false;
+// CHECK-PREDICATOR-NEXT:          FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
+// CHECK-PREDICATOR-NEXT:      if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:      unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:      if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:        if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    return true;
 // CHECK-PREDICATOR-NEXT:  }
 // CHECK-PREDICATOR-NEXT:  } // end namespace llvm
@@ -106,6 +138,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 
 // CHECK-SUBTARGET:      std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
 // CHECK-SUBTARGET-NEXT:   std::vector<MacroFusionPredTy> Fusions;
+// CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate); 
 // CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
 // CHECK-SUBTARGET-NEXT:   return Fusions;
 // CHECK-SUBTARGET-NEXT: }
diff --git a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
index 78dcd4471ae747..1042dd9c2dbccf 100644
--- a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
+++ b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
@@ -152,8 +152,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
         << "if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))\n";
     OS.indent(4) << "  return false;\n";
     OS.indent(2) << "}\n";
-  } else if (Predicate->isSubClassOf(
-                 "FirstFusionPredicateWithMCInstPredicate")) {
+  } else if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
     OS.indent(2) << "{\n";
     OS.indent(4) << "const MachineInstr *MI = FirstMI;\n";
     OS.indent(4) << "if (";
@@ -173,7 +172,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
 void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
                                                        PredicateExpander &PE,
                                                        raw_ostream &OS) {
-  if (Predicate->isSubClassOf("SecondFusionPredicateWithMCInstPredicate")) {
+  if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
     OS.indent(2) << "{\n";
     OS.indent(4) << "const MachineInstr *MI = &SecondMI;\n";
     OS.indent(4) << "if (";
@@ -183,9 +182,31 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
     OS << ")\n";
     OS.indent(4) << "  return false;\n";
     OS.indent(2) << "}\n";
+  } else if (Predicate->isSubClassOf("SameReg")) {
+    int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
+    int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
+
+    OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg().isVirtual()) {\n";
+    OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
+                 << ").getReg()) {\n";
+
+    OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
+    OS.indent(6) << "  return false;\n";
+
+    OS.indent(6) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+                 << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+    OS.indent(6)
+        << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+    OS.indent(6) << "  if (SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+    OS.indent(6) << "    return false;\n";
+    OS.indent(4) << "}\n";
+    OS.indent(2) << "}\n";
   } else {
     PrintFatalError(Predicate->getLoc(),
-                    "Unsupported predicate for first instruction: " +
+                    "Unsupported predicate for second instruction: " +
                         Predicate->getType()->getAsString());
   }
 }
@@ -196,9 +217,8 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
   if (Predicate->isSubClassOf("FusionPredicateWithCode"))
     OS << Predicate->getValueAsString("Predicate");
   else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
-    Record *MCPred = Predicate->getValueAsDef("Predicate");
-    emitFirstPredicate(MCPred, PE, OS);
-    emitSecondPredicate(MCPred, PE, OS);
+    emitFirstPredicate(Predicate, PE, OS);
+    emitSecondPredicate(Predicate, PE, OS);
   } else if (Predicate->isSubClassOf("TieReg")) {
     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
@@ -208,8 +228,19 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
                  << ").isReg() &&\n";
     OS.indent(2) << "      FirstMI->getOperand(" << FirstOpIdx
                  << ").getReg() == SecondMI.getOperand(" << SecondOpIdx
-                 << ").getReg()))\n";
-    OS.indent(2) << "  return false;\n";
+                 << ").getReg())) {\n";
+
+    OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
+    OS.indent(4) << "  return false;\n";
+
+    OS.indent(4) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+                 << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+    OS.indent(4)
+        << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+    OS.indent(4) << "  if (FirstMI->getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+    OS.indent(4) << "    return false;\n";
+    OS.indent(2) << "}\n";
   } else
     PrintFatalError(Predicate->getLoc(),
                     "Unsupported predicate for both instruction: " +

``````````

</details>


https://github.com/llvm/llvm-project/pull/82751


More information about the llvm-commits mailing list