[llvm] 016092d - Reapply "[X86][AMX] Try to hoist AMX shapes' def"
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 26 19:28:17 PDT 2021
Author: Wang, Pengfei
Date: 2021-04-27T10:27:59+08:00
New Revision: 016092d786f226f403fce5b5d0888dfa939b3f21
URL: https://github.com/llvm/llvm-project/commit/016092d786f226f403fce5b5d0888dfa939b3f21
DIFF: https://github.com/llvm/llvm-project/commit/016092d786f226f403fce5b5d0888dfa939b3f21.diff
LOG: Reapply "[X86][AMX] Try to hoist AMX shapes' def"
When we insert ldtilecfg, we require that no AMX instruction
intersects with the defs of its shapes. However, this does not always
hold, both because users may not follow the AMX API model and because
optimizations can reorder instructions. This patch adds a mechanism
that tries to hoist the defs of AMX shapes above the first AMX
instruction, as sketched below. It only hoists shapes within a single
BB; we can extend it to hoist across BBs in the future. Currently, it
only hoists shapes whose sources are all defined above the first AMX
instruction. We can later also handle the case where a source below
the AMX instruction merely moves an immediate value into a register.
Reviewed By: xiangzhangllvm
Differential Revision: https://reviews.llvm.org/D101067
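For illustration, here is a minimal before/after sketch of the hoist
inside a single BB. This is pseudo-MIR: the opcode spellings and
operands are illustrative, not actual compiler output.

  Before:
    %row = MOV16ri 16              ; shape def above the first AMX instruction
    PTILELOADDV %row, %col0, ...   ; first AMX instruction
    %col1 = SHR16ri %k, 2          ; shape def below the first AMX instruction
    PTILELOADDV %row, %col1, ...

  After hoistShapesInBB:
    %row = MOV16ri 16
    %col1 = SHR16ri %k, 2          ; hoisted; ldtilecfg can now be inserted
                                   ; after the last shape def
    PTILELOADDV %row, %col0, ...
    PTILELOADDV %row, %col1, ...

The hoist is refused if a shape def below the first AMX instruction
may load or store memory, or if any of its sources is itself defined
below the first AMX instruction (see hoistShapesInBB in the diff).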
Added:
Modified:
llvm/lib/Target/X86/X86PreTileConfig.cpp
llvm/test/CodeGen/X86/AMX/amx-sched.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index 4ec9e792cfe7e..9164dfd59cf20 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -57,6 +57,9 @@ struct MIRef {
++I, ++Pos)
MI = &*I;
}
+ MIRef(MachineInstr *MI)
+ : MI(MI), MBB(MI->getParent()),
+ Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
: MI(MI), MBB(MBB),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
@@ -66,6 +69,7 @@ struct MIRef {
bool operator==(const MIRef &RHS) const {
return MI == RHS.MI && MBB == RHS.MBB;
}
+ bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
bool operator<(const MIRef &RHS) const {
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
}
@@ -77,7 +81,7 @@ struct MIRef {
struct BBInfo {
MIRef FirstAMX;
MIRef LastCall;
- MIRef LastShape;
+ bool HasAMXRegLiveIn = false;
bool TileCfgForbidden = false;
bool NeedTileCfgLiveIn = false;
};
@@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass {
MachineRegisterInfo *MRI;
const MachineLoopInfo *MLI;
SmallSet<MachineInstr *, 8> DefVisited;
- SmallSet<MachineBasicBlock *, 8> ShapeBBs;
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
+ DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
/// Check if the callee will clobber AMX registers.
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
@@ -124,6 +128,32 @@ class X86PreTileConfig : public MachineFunctionPass {
/// Collect the shape def information for later use.
void collectShapeInfo(MachineInstr &MI);
+ /// Try to hoist shapes defined below AMX instructions.
+ bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
+ MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
+ auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
+ auto InsertPoint = FirstAMX.MI->getIterator();
+ for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
+ // Do not hoist instructions that access memory.
+ if (I->MI->mayLoadOrStore())
+ return false;
+ for (auto &MO : I->MI->operands()) {
+ if (MO.isDef())
+ continue;
+ // Do not hoist if a source's def is below the first AMX instruction.
+ // TODO: We can handle isMoveImmediate MI here.
+ if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
+ return false;
+ // TODO: Maybe need more checks here.
+ }
+ MBB->insert(InsertPoint, I->MI->removeFromParent());
+ }
+ // We only need to mark the last shape in the BB now.
+ Shapes.clear();
+ Shapes.push_back(MIRef(&*--InsertPoint, MBB));
+ return true;
+ }
+
public:
X86PreTileConfig() : MachineFunctionPass(ID) {}
@@ -165,9 +195,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
MIRef MIR(MI, MBB);
- if (BBVisitedInfo[MBB].LastShape < MIR)
- BBVisitedInfo[MBB].LastShape = MIR;
- ShapeBBs.insert(MBB);
+ auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
+ if (I == ShapeBBs[MBB].end() || *I != MIR)
+ ShapeBBs[MBB].insert(I, MIR);
};
SmallVector<Register, 8> WorkList(
@@ -229,6 +259,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
else
CfgLiveInBBs.push_back(&MBB);
}
+ if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
+ for (auto *Succ : MBB.successors())
+ if (!isLoopBackEdge(Succ, &MBB))
+ BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
}
// Update NeedTileCfgLiveIn for predecessors.
@@ -252,8 +286,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
return false;
// Avoid inserting ldtilecfg before any shape defs.
- SmallVector<MachineBasicBlock *, 8> WorkList(
- make_range(ShapeBBs.begin(), ShapeBBs.end()));
+ SmallVector<MachineBasicBlock *, 8> WorkList;
+ for (auto &I : ShapeBBs) {
+ // TODO: We can hoist shapes across BBs here.
+ if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
+ REPORT_CONFIG_FAIL
+ if (BBVisitedInfo[I.first].FirstAMX &&
+ BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
+ !hoistShapesInBB(I.first, I.second))
+ REPORT_CONFIG_FAIL
+ WorkList.push_back(I.first);
+ }
while (!WorkList.empty()) {
MachineBasicBlock *MBB = WorkList.pop_back_val();
for (auto *Pred : MBB->predecessors()) {
@@ -282,9 +325,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
} else {
// Avoid visiting the BB more than once.
VisitedOrInserted.insert(I);
- // We cannot sink it across any AMX instruction.
- if (BBVisitedInfo[I.MBB].FirstAMX)
- REPORT_CONFIG_FAIL;
// Sink the inserting point along the chain with NeedTileCfgLiveIn =
// true when MBB isn't all shapes reachable.
for (auto *Succ : I.MBB->successors())
@@ -296,14 +336,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
// A given point might be forked because shape conditions are not met.
for (MIRef I : InsertPoints) {
- // Even MBB is all shapes reachable, we still need to check if there's
- // AMX that intersects with shapes in the same MBB.
- if (BBVisitedInfo[I.MBB].FirstAMX &&
- BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
- REPORT_CONFIG_FAIL;
// Make sure we insert ldtilecfg after the last shape def in MBB.
- if (I < BBVisitedInfo[I.MBB].LastShape)
- I = BBVisitedInfo[I.MBB].LastShape;
+ if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
+ I = ShapeBBs[I.MBB].back();
// There's a chance the MBB is sunk more than once. Record it to avoid
// inserting it multiple times.
if (VisitedOrInserted.insert(I).second) {
diff --git a/llvm/test/CodeGen/X86/AMX/amx-sched.ll b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
index 7e704cf5bedff..790c6c94834cf 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-sched.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
@@ -2,6 +2,7 @@
define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%c1 = bitcast <256 x i32> %c to x86_amx
@@ -12,5 +13,19 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
ret <256 x i32> %res
}
+define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
+; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched2:
+; CHECK: ldtilecfg
+; CHECK-NOT: movw
+ %aa = lshr i16 %k, 2
+ %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
+ %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
+ %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
+ %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
+ %res = bitcast x86_amx %t to <256 x i32>
+ ret <256 x i32> %res
+}
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)