[llvm] 2c6e8b4 - [Matrix] Refactor tiled loops in a struct. NFC
Francis Visoiu Mistrih via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 26 11:03:04 PDT 2022
Author: Francis Visoiu Mistrih
Date: 2022-07-26T11:02:22-07:00
New Revision: 2c6e8b4636700f22a76eeda01e4a3258692b80f3
URL: https://github.com/llvm/llvm-project/commit/2c6e8b4636700f22a76eeda01e4a3258692b80f3
DIFF: https://github.com/llvm/llvm-project/commit/2c6e8b4636700f22a76eeda01e4a3258692b80f3.diff
LOG: [Matrix] Refactor tiled loops in a struct. NFC
The three loops have the same structure: index, header, latch.
Added:
Modified:
llvm/include/llvm/Transforms/Utils/MatrixUtils.h
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/lib/Transforms/Utils/MatrixUtils.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
index 39a0d4bf40ccb..ffad57002935e 100644
--- a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
@@ -25,9 +25,9 @@ class IRBuilderBase;
/// A helper struct to create IR loop nests for tiling in IR of the following
/// form:
-/// for CurrentColumn = 0..NumColumns
-/// for CurrentRow = 0..NumRows
-/// for CurrentInner = 0..NumInner
+/// for ColumnLoop.Index = 0..NumColumns
+/// for RowLoop.Index = 0..NumRows
+/// for KLoop.Index = 0..NumInner
struct TileInfo {
/// Number of rows of the matrix.
unsigned NumRows;
@@ -42,26 +42,21 @@ struct TileInfo {
/// Number of rows/columns in a tile.
unsigned TileSize = -1;
- /// Start row of the current tile to compute.
- Value *CurrentRow;
-
- /// Start column of the current tile to compute.
- Value *CurrentCol;
-
- /// Current tile offset during the tile computation.
- Value *CurrentK;
-
- /// Header of the outermost loop iterating from 0..NumColumns.
- BasicBlock *ColumnLoopHeader = nullptr;
-
- /// Header of the second loop iterating from 0..NumRows.
- BasicBlock *RowLoopHeader = nullptr;
- /// Latch of the second loop iterating from 0..NumRows.
- BasicBlock *RowLoopLatch = nullptr;
- /// Header of the innermost loop iterating from 0..NumInner.
- BasicBlock *InnerLoopHeader = nullptr;
- /// Latch of the innermost loop iterating from 0..NumInner.
- BasicBlock *InnerLoopLatch = nullptr;
+ /// Properties of a single loop used when generating the tiled loop nest.
+ struct MatrixLoop {
+ /// The index updated on every iteration.
+ Value *Index = nullptr;
+ /// The header and latch of the loop.
+ BasicBlock *Header = nullptr;
+ BasicBlock *Latch = nullptr;
+ };
+
+ /// The loop iterating on the rows.
+ MatrixLoop RowLoop;
+ /// The loop iterating on the columns.
+ MatrixLoop ColumnLoop;
+ /// The loop iterating on k (inner dimension).
+ MatrixLoop KLoop;
TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
unsigned TileSize)
@@ -72,9 +67,9 @@ struct TileInfo {
/// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
/// fields.
///
- /// for CurrentColumn = 0..NumColumns
- /// for CurrentRow = 0..NumRows
- /// for CurrentInner = 0..NumInner
+ /// for ColumnLoop.Index = 0..NumColumns
+ /// for RowLoop.Index = 0..NumRows
+ /// for KLoop.Index = 0..NumInner
BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, DomTreeUpdater &DTU,
LoopInfo &LI);
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index c05906649f167..73cd92d176ced 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1423,13 +1423,13 @@ class LowerMatrixIntrinsics {
FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
MatrixTy TileResult;
// Insert in the inner loop header.
- Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
+ Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
// Create PHI nodes for the result columns to accumulate across iterations.
SmallVector<PHINode *, 4> ColumnPhis;
for (unsigned I = 0; I < TileSize; I++) {
auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
- TI.RowLoopHeader->getSingleSuccessor());
+ TI.RowLoop.Header->getSingleSuccessor());
TileResult.addVector(Phi);
ColumnPhis.push_back(Phi);
}
@@ -1438,27 +1438,29 @@ class LowerMatrixIntrinsics {
// Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
Builder.SetInsertPoint(InnerBody->getTerminator());
// Load tiles of the operands.
- MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
- {TileSize, TileSize}, EltType, Builder);
- MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
- {TileSize, TileSize}, EltType, Builder);
+ MatrixTy A =
+ loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
+ {TileSize, TileSize}, EltType, Builder);
+ MatrixTy B =
+ loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
+ {TileSize, TileSize}, EltType, Builder);
emitMatrixMultiply(TileResult, A, B, Builder, true, false,
getFastMathFlags(MatMul));
// Store result after the inner loop is done.
- Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
+ Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
- TI.CurrentRow, TI.CurrentCol, EltType, Builder);
+ TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
- ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
+ ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
// Force unrolling of a few iterations of the inner loop, to make sure there
// is enough work per iteration.
// FIXME: The unroller should make this decision directly instead, but
// currently the cost-model is not up to the task.
unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
- addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
+ addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
"llvm.loop.unroll.count", InnerLoopUnrollCount);
}
diff --git a/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
index 6a137630deeb0..e218773cf5da1 100644
--- a/llvm/lib/Transforms/Utils/MatrixUtils.cpp
+++ b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
@@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, DomTreeUpdater &DTU,
LoopInfo &LI) {
- Loop *ColLoop = LI.AllocateLoop();
- Loop *RowLoop = LI.AllocateLoop();
- Loop *InnerLoop = LI.AllocateLoop();
- RowLoop->addChildLoop(InnerLoop);
- ColLoop->addChildLoop(RowLoop);
+ Loop *ColumnLoopInfo = LI.AllocateLoop();
+ Loop *RowLoopInfo = LI.AllocateLoop();
+ Loop *KLoopInfo = LI.AllocateLoop();
+ RowLoopInfo->addChildLoop(KLoopInfo);
+ ColumnLoopInfo->addChildLoop(RowLoopInfo);
if (Loop *ParentL = LI.getLoopFor(Start))
- ParentL->addChildLoop(ColLoop);
+ ParentL->addChildLoop(ColumnLoopInfo);
else
- LI.addTopLevelLoop(ColLoop);
+ LI.addTopLevelLoop(ColumnLoopInfo);
BasicBlock *ColBody =
CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
- "cols", B, DTU, ColLoop, LI);
- BasicBlock *ColLatch = ColBody->getSingleSuccessor();
+ "cols", B, DTU, ColumnLoopInfo, LI);
+ ColumnLoop.Latch = ColBody->getSingleSuccessor();
BasicBlock *RowBody =
- CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
- "rows", B, DTU, RowLoop, LI);
- RowLoopLatch = RowBody->getSingleSuccessor();
+ CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
+ B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
+ RowLoop.Latch = RowBody->getSingleSuccessor();
BasicBlock *InnerBody =
- CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
- B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
- InnerLoopLatch = InnerBody->getSingleSuccessor();
- ColumnLoopHeader = ColBody->getSinglePredecessor();
- RowLoopHeader = RowBody->getSinglePredecessor();
- InnerLoopHeader = InnerBody->getSinglePredecessor();
- CurrentRow = &*RowLoopHeader->begin();
- CurrentCol = &*ColumnLoopHeader->begin();
- CurrentK = &*InnerLoopHeader->begin();
+ CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
+ B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
+ KLoop.Latch = InnerBody->getSingleSuccessor();
+ ColumnLoop.Header = ColBody->getSinglePredecessor();
+ RowLoop.Header = RowBody->getSinglePredecessor();
+ KLoop.Header = InnerBody->getSinglePredecessor();
+ RowLoop.Index = &*RowLoop.Header->begin();
+ ColumnLoop.Index = &*ColumnLoop.Header->begin();
+ KLoop.Index = &*KLoop.Header->begin();
return InnerBody;
}
More information about the llvm-commits
mailing list