[llvm] caea37b - Revert "[X86][AMX] Try to hoist AMX shapes' def"

Mitch Phillips via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 23 10:43:06 PDT 2021


Author: Mitch Phillips
Date: 2021-04-23T10:42:26-07:00
New Revision: caea37b37e6aa8b0c1bb21526ad2d216b46a4b10

URL: https://github.com/llvm/llvm-project/commit/caea37b37e6aa8b0c1bb21526ad2d216b46a4b10
DIFF: https://github.com/llvm/llvm-project/commit/caea37b37e6aa8b0c1bb21526ad2d216b46a4b10.diff

LOG: Revert "[X86][AMX] Try to hoist AMX shapes' def"

This reverts commit 90118563ad0f133c696e070ad72761fa0daa4517.

Reason: It broke the MSan (MemorySanitizer) buildbots.
https://lab.llvm.org/buildbot/#/builders/5/builds/6967/steps/9/logs/stdio

More details can be found in the original Phabricator review:
https://reviews.llvm.org/D101067

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86PreTileConfig.cpp
    llvm/test/CodeGen/X86/AMX/amx-sched.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index a5057d093f6ce..4ec9e792cfe7e 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -57,9 +57,6 @@ struct MIRef {
          ++I, ++Pos)
       MI = &*I;
   }
-  MIRef(MachineInstr *MI)
-      : MI(MI), MBB(MI->getParent()),
-        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
       : MI(MI), MBB(MBB),
         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
@@ -69,7 +66,6 @@ struct MIRef {
   bool operator==(const MIRef &RHS) const {
     return MI == RHS.MI && MBB == RHS.MBB;
   }
-  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
   bool operator<(const MIRef &RHS) const {
     return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
   }
@@ -81,7 +77,7 @@ struct MIRef {
 struct BBInfo {
   MIRef FirstAMX;
   MIRef LastCall;
-  bool HasAMXRegLiveIn = false;
+  MIRef LastShape;
   bool TileCfgForbidden = false;
   bool NeedTileCfgLiveIn = false;
 };
@@ -90,8 +86,8 @@ class X86PreTileConfig : public MachineFunctionPass {
   MachineRegisterInfo *MRI;
   const MachineLoopInfo *MLI;
   SmallSet<MachineInstr *, 8> DefVisited;
+  SmallSet<MachineBasicBlock *, 8> ShapeBBs;
   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
-  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
 
   /// Check if the callee will clobber AMX registers.
   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
@@ -128,33 +124,6 @@ class X86PreTileConfig : public MachineFunctionPass {
   /// Collect the shape def information for later use.
   void collectShapeInfo(MachineInstr &MI);
 
-  /// Try to hoist shapes definded below AMX instructions.
-  bool hoistShapesInBB(MachineBasicBlock *MBB) {
-    auto FirstShapeBelowAMX =
-        llvm::lower_bound(ShapeBBs[MBB], BBVisitedInfo[MBB].FirstAMX);
-    auto InsertPoint = BBVisitedInfo[MBB].FirstAMX.MI->getIterator();
-    for (auto I = FirstShapeBelowAMX, E = ShapeBBs[MBB].end(); I != E; ++I) {
-      // Do not hoist instructions that access memory.
-      if (I->MI->mayLoadOrStore())
-        return false;
-      for (auto &MO : I->MI->operands()) {
-        if (MO.isDef())
-          continue;
-        // Do not hoist instructions if the sources' def under AMX instruction.
-        // TODO: We can handle isMoveImmediate MI here.
-        if (MO.isReg() &&
-            MIRef(MRI->getVRegDef(MO.getReg())) > BBVisitedInfo[MBB].FirstAMX)
-          return false;
-        // TODO: Maybe need more checks here.
-      }
-      MBB->insert(InsertPoint, I->MI->removeFromParent());
-    }
-    // We only need to mark the last shape in the BB now.
-    ShapeBBs[MBB].clear();
-    ShapeBBs[MBB].push_back(MIRef(&*--InsertPoint, MBB));
-    return true;
-  }
-
 public:
   X86PreTileConfig() : MachineFunctionPass(ID) {}
 
@@ -196,9 +165,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
     MIRef MIR(MI, MBB);
-    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
-    if (*I != MIR)
-      ShapeBBs[MBB].insert(I, MIR);
+    if (BBVisitedInfo[MBB].LastShape < MIR)
+      BBVisitedInfo[MBB].LastShape = MIR;
+    ShapeBBs.insert(MBB);
   };
 
   SmallVector<Register, 8> WorkList(
@@ -260,10 +229,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
       else
         CfgLiveInBBs.push_back(&MBB);
     }
-    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
-      for (auto *Succ : MBB.successors())
-        if (!isLoopBackEdge(Succ, &MBB))
-          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
   }
 
   // Update NeedTileCfgLiveIn for predecessors.
@@ -287,17 +252,8 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   // Avoid to insert ldtilecfg before any shape defs.
-  SmallVector<MachineBasicBlock *, 8> WorkList;
-  for (auto &I : ShapeBBs) {
-    // TODO: We can hoist shapes across BBs here.
-    if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
-      REPORT_CONFIG_FAIL
-    if (BBVisitedInfo[I.first].FirstAMX &&
-        BBVisitedInfo[I.first].FirstAMX < ShapeBBs[I.first].back() &&
-        !hoistShapesInBB(I.first))
-      REPORT_CONFIG_FAIL
-    WorkList.push_back(I.first);
-  }
+  SmallVector<MachineBasicBlock *, 8> WorkList(
+      make_range(ShapeBBs.begin(), ShapeBBs.end()));
   while (!WorkList.empty()) {
     MachineBasicBlock *MBB = WorkList.pop_back_val();
     for (auto *Pred : MBB->predecessors()) {
@@ -326,6 +282,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
         } else {
           // Avoid the BB to be multi visited.
           VisitedOrInserted.insert(I);
+          // We cannot sink it across any AMX instruction.
+          if (BBVisitedInfo[I.MBB].FirstAMX)
+            REPORT_CONFIG_FAIL;
           // Sink the inserting point along the chain with NeedTileCfgLiveIn =
           // true when MBB isn't all shapes reachable.
           for (auto *Succ : I.MBB->successors())
@@ -337,9 +296,14 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
 
     // A given point might be forked due to shape conditions are not met.
     for (MIRef I : InsertPoints) {
+      // Even MBB is all shapes reachable, we still need to check if there's
+      // AMX that intersects with shapes in the same MBB.
+      if (BBVisitedInfo[I.MBB].FirstAMX &&
+          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
+        REPORT_CONFIG_FAIL;
       // Make sure we insert ldtilecfg after the last shape def in MBB.
-      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
-        I = ShapeBBs[I.MBB].back();
+      if (I < BBVisitedInfo[I.MBB].LastShape)
+        I = BBVisitedInfo[I.MBB].LastShape;
       // There're chances the MBB is sunk more than once. Record it to avoid
       // multi insert.
       if (VisitedOrInserted.insert(I).second) {

diff  --git a/llvm/test/CodeGen/X86/AMX/amx-sched.ll b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
index 790c6c94834cf..7e704cf5bedff 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-sched.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
@@ -2,7 +2,6 @@
 
 define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
 ; Just to make sure shape def is not scheduled across ldtilecfg.
-; CHECK-LABEL: test_shape_sched:
 ; CHECK:    ldtilecfg
 ; CHECK-NOT: movw
   %c1 = bitcast <256 x i32> %c to x86_amx
@@ -13,19 +12,5 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
   ret <256 x i32> %res
 }
 
-define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
-; Just to make sure shape def is not scheduled across ldtilecfg.
-; CHECK-LABEL: test_shape_sched2:
-; CHECK:    ldtilecfg
-; CHECK-NOT: movw
-  %aa = lshr i16 %k, 2
-  %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
-  %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
-  %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
-  %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
-  %res = bitcast x86_amx %t to <256 x i32>
-  ret <256 x i32> %res
-}
 
-declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)


        


More information about the llvm-commits mailing list