[llvm] b890a48 - [MacroFusion] Support commutable instructions (#82751)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 15 03:44:52 PDT 2024


Author: Wang Pengcheng
Date: 2024-03-15T18:44:49+08:00
New Revision: b890a48a12aa5c851185ae2fd6273cd853fe0bc5

URL: https://github.com/llvm/llvm-project/commit/b890a48a12aa5c851185ae2fd6273cd853fe0bc5
DIFF: https://github.com/llvm/llvm-project/commit/b890a48a12aa5c851185ae2fd6273cd853fe0bc5.diff

LOG: [MacroFusion] Support commutable instructions (#82751)

If the second instruction is commutable, we should be able to check
its commutable operands.

This PR contains a simple RISC-V fusion to demonstrate that the functionality
is correct; I may remove it before landing.

Fixes #82738

Added: 
    

Modified: 
    llvm/include/llvm/Target/TargetSchedule.td
    llvm/test/TableGen/MacroFusion.td
    llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Target/TargetSchedule.td b/llvm/include/llvm/Target/TargetSchedule.td
index 069eb2900bfe68..d8158eb01ad45e 100644
--- a/llvm/include/llvm/Target/TargetSchedule.td
+++ b/llvm/include/llvm/Target/TargetSchedule.td
@@ -622,11 +622,22 @@ class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
 // Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
 // `firstOpIdx` should be the same as the operand of `SecondMI` at position
 // `secondOpIdx`.
+// If the fusion sets `IsCommutable` and `SecondMI` is commutable, the operand
+// that `TII.findCommutedOpIndices` pairs with `secondOpIdx` may match instead.
+// NOTE(review): when no such operand is found, the generated check passes;
+// confirm that accepting the pair in that case is intended.
 class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
   int FirstOpIdx = firstOpIdx;
   int SecondOpIdx = secondOpIdx;
 }
 
+// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
+// operand at position `secondOpIdx`. The check is only performed when the
+// operand at `firstOpIdx` is not a virtual register.
+// If the fusion sets `IsCommutable` and `SecondMI` is commutable, the operand
+// that `TII.findCommutedOpIndices` pairs with `secondOpIdx` may match instead.
+// NOTE(review): when no such operand is found, the generated check passes;
+// confirm that accepting the pair in that case is intended.
+class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
+  int FirstOpIdx = firstOpIdx;
+  int SecondOpIdx = secondOpIdx;
+}
+
 // A predicate for wildcard. The generated code will be like:
 // ```
 // if (!FirstMI)
@@ -655,9 +666,12 @@ def OneUse : OneUsePred;
 //   return true;
 // }
 // ```
+//
+// `IsCommutable` (default 0) lets the `TieReg` and `SameReg` checks also accept
+// a matching commutable operand of `SecondMI`.
 class Fusion<string name, string fieldName, string desc, list<FusionPredicate> predicates>
   : SubtargetFeature<name, fieldName, "true", desc> {
   list<FusionPredicate> Predicates = predicates;
+  bit IsCommutable = 0;
 }
 
 // The generated predicator will be like:
@@ -671,6 +685,7 @@ class Fusion<string name, string fieldName, string desc, list<FusionPredicate> p
 //   /* Predicate for `SecondMI` */
 //   /* Wildcard */
 //   /* Predicate for `FirstMI` */
+//   /* Check same registers */
 //   /* Check One Use */
 //   /* Tie registers */
 //   /* Epilog */
@@ -688,11 +703,7 @@ class SimpleFusion<string name, string fieldName, string desc,
                 SecondFusionPredicateWithMCInstPredicate<secondPred>,
                 WildcardTrue,
                 FirstFusionPredicateWithMCInstPredicate<firstPred>,
-                SecondFusionPredicateWithMCInstPredicate<
-                  CheckAny<[
-                    CheckIsVRegOperand<0>,
-                    CheckSameRegOperand<0, 1>
-                  ]>>,
+                SameReg<0, 1>,
                 OneUse,
                 TieReg<0, 1>,
               ],

diff  --git a/llvm/test/TableGen/MacroFusion.td b/llvm/test/TableGen/MacroFusion.td
index ce76e7f0f7fa64..05c970cbd22455 100644
--- a/llvm/test/TableGen/MacroFusion.td
+++ b/llvm/test/TableGen/MacroFusion.td
@@ -46,11 +46,21 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
                               CheckRegOperand<0, X0>
                              ]>>;
 
+// Same predicates as TestFusion, but with `IsCommutable` set so the generated
+// SameReg/TieReg checks also consider commutable operands of SecondMI.
+let IsCommutable = 1 in
+def TestCommutableFusion: SimpleFusion<"test-commutable-fusion", "HasTestCommutableFusion",
+                                       "Test Commutable Fusion",
+                                       CheckOpcode<[Inst0]>,
+                                       CheckAll<[
+                                        CheckOpcode<[Inst1]>,
+                                        CheckRegOperand<0, X0>
+                                       ]>>;
+
 // CHECK-PREDICATOR:       #ifdef GET_Test_MACRO_FUSION_PRED_DECL
 // CHECK-PREDICATOR-NEXT:  #undef GET_Test_MACRO_FUSION_PRED_DECL
 // CHECK-PREDICATOR-EMPTY:
 // CHECK-PREDICATOR-NEXT:  namespace llvm {
 // CHECK-PREDICATOR-NEXT:  bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
+// CHECK-PREDICATOR-NEXT:  bool isTestCommutableFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);   
 // CHECK-PREDICATOR-NEXT:  bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
 // CHECK-PREDICATOR-NEXT:  } // end namespace llvm
 // CHECK-PREDICATOR-EMPTY:
@@ -78,7 +88,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    return true;
 // CHECK-PREDICATOR-NEXT:  }
-// CHECK-PREDICATOR-NEXT:  bool isTestFusion(
+// CHECK-PREDICATOR-NEXT:  bool isTestCommutableFusion(
 // CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
 // CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
 // CHECK-PREDICATOR-NEXT:      const MachineInstr *FirstMI,
@@ -99,14 +109,58 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:      if (( MI->getOpcode() != Test::Inst0 ))
 // CHECK-PREDICATOR-NEXT:        return false;
 // CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    if (!SecondMI.getOperand(0).getReg().isVirtual()) {
+// CHECK-PREDICATOR-NEXT:      if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {
+// CHECK-PREDICATOR-NEXT:        if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:        unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:        if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:          if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:            return false;
+// CHECK-PREDICATOR-NEXT:      }
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      Register FirstDest = FirstMI->getOperand(0).getReg();
+// CHECK-PREDICATOR-NEXT:      if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    if (!(FirstMI->getOperand(0).isReg() &&
+// CHECK-PREDICATOR-NEXT:          SecondMI.getOperand(1).isReg() &&
+// CHECK-PREDICATOR-NEXT:          FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
+// CHECK-PREDICATOR-NEXT:      if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:      unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:      if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:        if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    return true;
+// CHECK-PREDICATOR-NEXT:  }
+// CHECK-PREDICATOR-NEXT:  bool isTestFusion(
+// CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
+// CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *FirstMI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr &SecondMI) {
+// CHECK-PREDICATOR-NEXT:    auto &MRI = SecondMI.getMF()->getRegInfo();
 // CHECK-PREDICATOR-NEXT:    {
 // CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = &SecondMI;
 // CHECK-PREDICATOR-NEXT:      if (!(
-// CHECK-PREDICATOR-NEXT:          MI->getOperand(0).getReg().isVirtual()
-// CHECK-PREDICATOR-NEXT:          || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
+// CHECK-PREDICATOR-NEXT:          ( MI->getOpcode() == Test::Inst1 )
+// CHECK-PREDICATOR-NEXT:          && MI->getOperand(0).getReg() == Test::X0
 // CHECK-PREDICATOR-NEXT:        ))
 // CHECK-PREDICATOR-NEXT:        return false;
 // CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    if (!FirstMI)
+// CHECK-PREDICATOR-NEXT:      return true;
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = FirstMI;
+// CHECK-PREDICATOR-NEXT:      if (( MI->getOpcode() != Test::Inst0 ))
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    if (!SecondMI.getOperand(0).getReg().isVirtual()) {
+// CHECK-PREDICATOR-NEXT:      if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg())
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    {
 // CHECK-PREDICATOR-NEXT:      Register FirstDest = FirstMI->getOperand(0).getReg();
 // CHECK-PREDICATOR-NEXT:      if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
@@ -131,6 +185,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-SUBTARGET:      std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
 // CHECK-SUBTARGET-NEXT:   std::vector<MacroFusionPredTy> Fusions;
 // CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate); 
+// CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestCommutableFusion)) Fusions.push_back(llvm::isTestCommutableFusion); 
 // CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
 // CHECK-SUBTARGET-NEXT:   return Fusions;
 // CHECK-SUBTARGET-NEXT: }

diff  --git a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
index 7f494e532b1f44..91c3b0b4359cf0 100644
--- a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
+++ b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
@@ -40,12 +40,10 @@
 
 #include "CodeGenTarget.h"
 #include "PredicateExpander.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
-#include <set>
 #include <vector>
 
 using namespace llvm;
@@ -61,14 +59,14 @@ class MacroFusionPredicatorEmitter {
                            raw_ostream &OS);
   void emitMacroFusionImpl(std::vector<Record *> Fusions, PredicateExpander &PE,
                            raw_ostream &OS);
-  void emitPredicates(std::vector<Record *> &FirstPredicate,
+  void emitPredicates(std::vector<Record *> &FirstPredicate, bool IsCommutable,
                       PredicateExpander &PE, raw_ostream &OS);
-  void emitFirstPredicate(Record *SecondPredicate, PredicateExpander &PE,
-                          raw_ostream &OS);
-  void emitSecondPredicate(Record *SecondPredicate, PredicateExpander &PE,
-                           raw_ostream &OS);
-  void emitBothPredicate(Record *Predicates, PredicateExpander &PE,
-                         raw_ostream &OS);
+  void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable,
+                          PredicateExpander &PE, raw_ostream &OS);
+  void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable,
+                           PredicateExpander &PE, raw_ostream &OS);
+  void emitBothPredicate(Record *Predicates, bool IsCommutable,
+                         PredicateExpander &PE, raw_ostream &OS);
 
 public:
   MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {}
@@ -103,6 +101,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
   for (Record *Fusion : Fusions) {
     std::vector<Record *> Predicates =
         Fusion->getValueAsListOfDefs("Predicates");
+    bool IsCommutable = Fusion->getValueAsBit("IsCommutable");
 
     OS << "bool is" << Fusion->getName() << "(\n";
     OS.indent(4) << "const TargetInstrInfo &TII,\n";
@@ -111,7 +110,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
     OS.indent(4) << "const MachineInstr &SecondMI) {\n";
     OS.indent(2) << "auto &MRI = SecondMI.getMF()->getRegInfo();\n";
 
-    emitPredicates(Predicates, PE, OS);
+    emitPredicates(Predicates, IsCommutable, PE, OS);
 
     OS.indent(2) << "return true;\n";
     OS << "}\n";
@@ -122,15 +121,16 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
 }
 
 void MacroFusionPredicatorEmitter::emitPredicates(
-    std::vector<Record *> &Predicates, PredicateExpander &PE, raw_ostream &OS) {
+    std::vector<Record *> &Predicates, bool IsCommutable, PredicateExpander &PE,
+    raw_ostream &OS) {
   for (Record *Predicate : Predicates) {
     Record *Target = Predicate->getValueAsDef("Target");
     if (Target->getName() == "first_fusion_target")
-      emitFirstPredicate(Predicate, PE, OS);
+      emitFirstPredicate(Predicate, IsCommutable, PE, OS);
     else if (Target->getName() == "second_fusion_target")
-      emitSecondPredicate(Predicate, PE, OS);
+      emitSecondPredicate(Predicate, IsCommutable, PE, OS);
     else if (Target->getName() == "both_fusion_target")
-      emitBothPredicate(Predicate, PE, OS);
+      emitBothPredicate(Predicate, IsCommutable, PE, OS);
     else
       PrintFatalError(Target->getLoc(),
                       "Unsupported 'FusionTarget': " + Target->getName());
@@ -138,6 +138,7 @@ void MacroFusionPredicatorEmitter::emitPredicates(
 }
 
 void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
+                                                      bool IsCommutable,
                                                       PredicateExpander &PE,
                                                       raw_ostream &OS) {
   if (Predicate->isSubClassOf("WildcardPred")) {
@@ -170,6 +171,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
 }
 
 void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
+                                                       bool IsCommutable,
                                                        PredicateExpander &PE,
                                                        raw_ostream &OS) {
   if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
@@ -182,6 +184,36 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
     OS << ")\n";
     OS.indent(4) << "  return false;\n";
     OS.indent(2) << "}\n";
+  } else if (Predicate->isSubClassOf("SameReg")) {
+    int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
+    int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
+
+    OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg().isVirtual()) {\n";
+    OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
+                 << ").getReg())";
+
+    if (IsCommutable) {
+      OS << " {\n";
+      OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
+      OS.indent(6) << "  return false;\n";
+
+      OS.indent(6)
+          << "unsigned SrcOpIdx1 = " << SecondOpIdx
+          << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+      OS.indent(6)
+          << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+      OS.indent(6)
+          << "  if (SecondMI.getOperand(" << FirstOpIdx
+          << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+      OS.indent(6) << "    return false;\n";
+      OS.indent(4) << "}\n";
+    } else {
+      OS << "\n";
+      OS.indent(4) << "  return false;\n";
+    }
+    OS.indent(2) << "}\n";
   } else {
     PrintFatalError(Predicate->getLoc(),
                     "Unsupported predicate for second instruction: " +
@@ -190,13 +222,14 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
 }
 
 void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
+                                                     bool IsCommutable,
                                                      PredicateExpander &PE,
                                                      raw_ostream &OS) {
   if (Predicate->isSubClassOf("FusionPredicateWithCode"))
     OS << Predicate->getValueAsString("Predicate");
   else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
-    emitFirstPredicate(Predicate, PE, OS);
-    emitSecondPredicate(Predicate, PE, OS);
+    emitFirstPredicate(Predicate, IsCommutable, PE, OS);
+    emitSecondPredicate(Predicate, IsCommutable, PE, OS);
   } else if (Predicate->isSubClassOf("TieReg")) {
     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
@@ -206,8 +239,28 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
                  << ").isReg() &&\n";
     OS.indent(2) << "      FirstMI->getOperand(" << FirstOpIdx
                  << ").getReg() == SecondMI.getOperand(" << SecondOpIdx
-                 << ").getReg()))\n";
-    OS.indent(2) << "  return false;\n";
+                 << ").getReg()))";
+
+    if (IsCommutable) {
+      OS << " {\n";
+      OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
+      OS.indent(4) << "  return false;\n";
+
+      OS.indent(4)
+          << "unsigned SrcOpIdx1 = " << SecondOpIdx
+          << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+      OS.indent(4)
+          << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+      OS.indent(4)
+          << "  if (FirstMI->getOperand(" << FirstOpIdx
+          << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+      OS.indent(4) << "    return false;\n";
+      OS.indent(2) << "}";
+    } else {
+      OS << "\n";
+      OS.indent(2) << "  return false;";
+    }
+    OS << "\n";
   } else
     PrintFatalError(Predicate->getLoc(),
                     "Unsupported predicate for both instruction: " +


        


More information about the llvm-commits mailing list