[llvm] 9011856 - [X86][AMX] Try to hoist AMX shapes' def

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 22 21:17:11 PDT 2021


Author: Wang, Pengfei
Date: 2021-04-23T12:17:00+08:00
New Revision: 90118563ad0f133c696e070ad72761fa0daa4517

URL: https://github.com/llvm/llvm-project/commit/90118563ad0f133c696e070ad72761fa0daa4517
DIFF: https://github.com/llvm/llvm-project/commit/90118563ad0f133c696e070ad72761fa0daa4517.diff

LOG: [X86][AMX] Try to hoist AMX shapes' def

We require that AMX instructions do not intersect with their shapes'
defs when we insert ldtilecfg. However, this does not always hold,
both because users may not follow the AMX API model and because of
optimizations.

This patch adds a mechanism that tries to hoist the AMX shapes' defs
as well. It only hoists shapes within a single BB; this can be
extended to cases across BBs in the future. Currently, it only hoists
shapes whose sources are all defined above the first AMX instruction.
This could be improved to also handle the case where the only source
defined below the AMX instruction moves an immediate into a register.

Differential Revision: https://reviews.llvm.org/D101067

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86PreTileConfig.cpp
    llvm/test/CodeGen/X86/AMX/amx-sched.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index 0ffd43fbe80ae..48d68d72efe6f 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -57,6 +57,9 @@ struct MIRef {
          ++I, ++Pos)
       MI = &*I;
   }
+  MIRef(MachineInstr *MI)
+      : MI(MI), MBB(MI->getParent()),
+        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
       : MI(MI), MBB(MBB),
         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
@@ -66,6 +69,7 @@ struct MIRef {
   bool operator==(const MIRef &RHS) const {
     return MI == RHS.MI && MBB == RHS.MBB;
   }
+  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
   bool operator<(const MIRef &RHS) const {
     return (!MBB && RHS.MBB) || (MBB == RHS.MBB && Pos < RHS.Pos);
   }
@@ -77,7 +81,7 @@ struct MIRef {
 struct BBInfo {
   MIRef FirstAMX;
   MIRef LastCall;
-  MIRef LastShape;
+  bool HasAMXRegLiveIn = false;
   bool TileCfgForbidden = false;
   bool NeedTileCfgLiveIn = false;
 };
@@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass {
   MachineRegisterInfo *MRI;
   const MachineLoopInfo *MLI;
   SmallSet<MachineInstr *, 8> DefVisited;
-  SmallSet<MachineBasicBlock *, 8> ShapeBBs;
   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
+  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
 
   /// Check if the callee will clobber AMX registers.
   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
@@ -124,6 +128,33 @@ class X86PreTileConfig : public MachineFunctionPass {
   /// Collect the shape def information for later use.
   void collectShapeInfo(MachineInstr &MI);
 
+  /// Hoist shapes below MBB's first AMX MI above it; false if not possible.
+  bool hoistShapesInBB(MachineBasicBlock *MBB) {
+    auto FirstShapeBelowAMX =
+        llvm::lower_bound(ShapeBBs[MBB], BBVisitedInfo[MBB].FirstAMX);
+    auto InsertPoint = BBVisitedInfo[MBB].FirstAMX.MI->getIterator();
+    for (auto I = FirstShapeBelowAMX, E = ShapeBBs[MBB].end(); I != E; ++I) {
+      // Do not hoist instructions that access memory.
+      if (I->MI->mayLoadOrStore())
+        return false;
+      for (auto &MO : I->MI->operands()) {
+        if (MO.isDef())
+          continue;
+        // Do not hoist if any source is defined below the first AMX MI.
+        // TODO: We can handle isMoveImmediate MI here.
+        if (MO.isReg() &&
+            MIRef(MRI->getVRegDef(MO.getReg())) > BBVisitedInfo[MBB].FirstAMX)
+          return false;
+        // TODO: More checks, e.g. getVRegDef() returning null (MIRef derefs).
+      }
+      MBB->insert(InsertPoint, I->MI->removeFromParent());
+    }
+    // Record only the last shape. NOTE(review): FirstAMX.Pos is now stale.
+    ShapeBBs[MBB].clear();
+    ShapeBBs[MBB].push_back(MIRef(&*--InsertPoint, MBB));
+    return true;
+  }
+
 public:
   X86PreTileConfig() : MachineFunctionPass(ID) {}
 
@@ -165,9 +196,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
     MIRef MIR(MI, MBB);
-    if (BBVisitedInfo[MBB].LastShape < MIR)
-      BBVisitedInfo[MBB].LastShape = MIR;
-    ShapeBBs.insert(MBB);
+    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); // Keep sorted, unique.
+    if (I == ShapeBBs[MBB].end() || *I != MIR) // I may be end().
+      ShapeBBs[MBB].insert(I, MIR);
   };
 
   SmallVector<Register, 8> WorkList(
@@ -229,6 +260,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
       else
         CfgLiveInBBs.push_back(&MBB);
     }
+    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
+      for (auto *Succ : MBB.successors())
+        if (!isLoopBackEdge(Succ, &MBB))
+          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
   }
 
   // Update NeedTileCfgLiveIn for predecessors.
@@ -252,8 +287,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   // Avoid to insert ldtilecfg before any shape defs.
-  SmallVector<MachineBasicBlock *, 8> WorkList(
-      make_range(ShapeBBs.begin(), ShapeBBs.end()));
+  SmallVector<MachineBasicBlock *, 8> WorkList;
+  for (auto &I : ShapeBBs) {
+    // TODO: We can hoist shapes across BBs here.
+    if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
+      REPORT_CONFIG_FAIL
+    if (BBVisitedInfo[I.first].FirstAMX &&
+        BBVisitedInfo[I.first].FirstAMX < ShapeBBs[I.first].back() &&
+        !hoistShapesInBB(I.first))
+      REPORT_CONFIG_FAIL
+    WorkList.push_back(I.first);
+  }
   while (!WorkList.empty()) {
     MachineBasicBlock *MBB = WorkList.pop_back_val();
     for (auto *Pred : MBB->predecessors()) {
@@ -282,9 +326,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
         } else {
           // Avoid the BB to be multi visited.
           VisitedOrInserted.insert(I);
-          // We cannot sink it across any AMX instruction.
-          if (BBVisitedInfo[I.MBB].FirstAMX)
-            REPORT_CONFIG_FAIL;
           // Sink the inserting point along the chain with NeedTileCfgLiveIn =
           // true when MBB isn't all shapes reachable.
           for (auto *Succ : I.MBB->successors())
@@ -296,14 +337,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
 
     // A given point might be forked due to shape conditions are not met.
     for (MIRef I : InsertPoints) {
-      // Even MBB is all shapes reachable, we still need to check if there's
-      // AMX that intersects with shapes in the same MBB.
-      if (BBVisitedInfo[I.MBB].FirstAMX &&
-          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
-        REPORT_CONFIG_FAIL;
       // Make sure we insert ldtilecfg after the last shape def in MBB.
-      if (I < BBVisitedInfo[I.MBB].LastShape)
-        I = BBVisitedInfo[I.MBB].LastShape;
+      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
+        I = ShapeBBs[I.MBB].back();
       // There're chances the MBB is sunk more than once. Record it to avoid
       // multi insert.
       if (VisitedOrInserted.insert(I).second) {

diff  --git a/llvm/test/CodeGen/X86/AMX/amx-sched.ll b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
index 7e704cf5bedff..790c6c94834cf 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-sched.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
@@ -2,6 +2,7 @@
 
 define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
 ; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched:
 ; CHECK:    ldtilecfg
 ; CHECK-NOT: movw
   %c1 = bitcast <256 x i32> %c to x86_amx
@@ -12,5 +13,19 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
   ret <256 x i32> %res
 }
 
+define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
+; Make sure the shape defs (including %aa) are not left below ldtilecfg.
+; CHECK-LABEL: test_shape_sched2:
+; CHECK:    ldtilecfg
+; CHECK-NOT: movw
+  %aa = lshr i16 %k, 2
+  %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
+  %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
+  %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
+  %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
+  %res = bitcast x86_amx %t to <256 x i32>
+  ret <256 x i32> %res
+}
 
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)


        


More information about the llvm-commits mailing list