[llvm] 2c6e8b4 - [Matrix] Refactor tiled loops in a struct. NFC

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 26 11:03:04 PDT 2022


Author: Francis Visoiu Mistrih
Date: 2022-07-26T11:02:22-07:00
New Revision: 2c6e8b4636700f22a76eeda01e4a3258692b80f3

URL: https://github.com/llvm/llvm-project/commit/2c6e8b4636700f22a76eeda01e4a3258692b80f3
DIFF: https://github.com/llvm/llvm-project/commit/2c6e8b4636700f22a76eeda01e4a3258692b80f3.diff

LOG: [Matrix] Refactor tiled loops in a struct. NFC

The three loops have the same structure: index, header, latch.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/MatrixUtils.h
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/lib/Transforms/Utils/MatrixUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
index 39a0d4bf40ccb..ffad57002935e 100644
--- a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
@@ -25,9 +25,9 @@ class IRBuilderBase;
 
 /// A helper struct to create IR loop nests for tiling in IR of the following
 /// form:
-///   for CurrentColumn = 0..NumColumns
-///     for CurrentRow = 0..NumRows
-///       for CurrentInner = 0..NumInner
+///   for ColumnLoop.Index = 0..NumColumns
+///     for RowLoop.Index = 0..NumRows
+///       for KLoop.Index = 0..NumInner
 struct TileInfo {
   /// Number of rows of the matrix.
   unsigned NumRows;
@@ -42,26 +42,21 @@ struct TileInfo {
   /// Number of rows/columns in a tile.
   unsigned TileSize = -1;
 
-  /// Start row of the current tile to compute.
-  Value *CurrentRow;
-
-  /// Start column of the current tile to compute.
-  Value *CurrentCol;
-
-  /// Current tile offset during the tile computation.
-  Value *CurrentK;
-
-  /// Header of the outermost loop iterating from 0..NumColumns.
-  BasicBlock *ColumnLoopHeader = nullptr;
-
-  /// Header of the second loop iterating from 0..NumRows.
-  BasicBlock *RowLoopHeader = nullptr;
-  /// Latch of the second loop iterating from 0..NumRows.
-  BasicBlock *RowLoopLatch = nullptr;
-  /// Header of the innermost loop iterating from 0..NumInner.
-  BasicBlock *InnerLoopHeader = nullptr;
-  /// Latch of the innermost loop iterating from 0..NumInner.
-  BasicBlock *InnerLoopLatch = nullptr;
+  /// Properties of a single loop used when generating the tiled loop nest.
+  struct MatrixLoop {
+    /// The index updated on every iteration.
+    Value *Index = nullptr;
+    /// The header and latch of the loop.
+    BasicBlock *Header = nullptr;
+    BasicBlock *Latch = nullptr;
+  };
+
+  /// The loop iterating on the rows.
+  MatrixLoop RowLoop;
+  /// The loop iterating on the columns.
+  MatrixLoop ColumnLoop;
+  /// The loop iterating on k (inner dimension).
+  MatrixLoop KLoop;
 
   TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
            unsigned TileSize)
@@ -72,9 +67,9 @@ struct TileInfo {
   /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
   /// fields.
   ///
-  /// for CurrentColumn = 0..NumColumns
-  ///   for CurrentRow = 0..NumRows
-  ///     for CurrentInner = 0..NumInner
+  /// for ColumnLoop.Index = 0..NumColumns
+  ///   for RowLoop.Index = 0..NumRows
+  ///     for KLoop.Index = 0..NumInner
   BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
                                IRBuilderBase &B, DomTreeUpdater &DTU,
                                LoopInfo &LI);

diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index c05906649f167..73cd92d176ced 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1423,13 +1423,13 @@ class LowerMatrixIntrinsics {
         FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
     MatrixTy TileResult;
     // Insert in the inner loop header.
-    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
+    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
     // Create PHI nodes for the result columns to accumulate across iterations.
     SmallVector<PHINode *, 4> ColumnPhis;
     for (unsigned I = 0; I < TileSize; I++) {
       auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
       Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
-                       TI.RowLoopHeader->getSingleSuccessor());
+                       TI.RowLoop.Header->getSingleSuccessor());
       TileResult.addVector(Phi);
       ColumnPhis.push_back(Phi);
     }
@@ -1438,27 +1438,29 @@ class LowerMatrixIntrinsics {
     //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
     Builder.SetInsertPoint(InnerBody->getTerminator());
     // Load tiles of the operands.
-    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
-                            {TileSize, TileSize}, EltType, Builder);
-    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
-                            {TileSize, TileSize}, EltType, Builder);
+    MatrixTy A =
+        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
+                   {TileSize, TileSize}, EltType, Builder);
+    MatrixTy B =
+        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
+                   {TileSize, TileSize}, EltType, Builder);
     emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                        getFastMathFlags(MatMul));
     // Store result after the inner loop is done.
-    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
+    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
     storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
-                TI.CurrentRow, TI.CurrentCol, EltType, Builder);
+                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
 
     for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
-      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
+      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
 
     // Force unrolling of a few iterations of the inner loop, to make sure there
     // is enough work per iteration.
     // FIXME: The unroller should make this decision directly instead, but
     // currently the cost-model is not up to the task.
     unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
-    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
+    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                             "llvm.loop.unroll.count", InnerLoopUnrollCount);
   }
 

diff  --git a/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
index 6a137630deeb0..e218773cf5da1 100644
--- a/llvm/lib/Transforms/Utils/MatrixUtils.cpp
+++ b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
@@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
 BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
                                        IRBuilderBase &B, DomTreeUpdater &DTU,
                                        LoopInfo &LI) {
-  Loop *ColLoop = LI.AllocateLoop();
-  Loop *RowLoop = LI.AllocateLoop();
-  Loop *InnerLoop = LI.AllocateLoop();
-  RowLoop->addChildLoop(InnerLoop);
-  ColLoop->addChildLoop(RowLoop);
+  Loop *ColumnLoopInfo = LI.AllocateLoop();
+  Loop *RowLoopInfo = LI.AllocateLoop();
+  Loop *KLoopInfo = LI.AllocateLoop();
+  RowLoopInfo->addChildLoop(KLoopInfo);
+  ColumnLoopInfo->addChildLoop(RowLoopInfo);
   if (Loop *ParentL = LI.getLoopFor(Start))
-    ParentL->addChildLoop(ColLoop);
+    ParentL->addChildLoop(ColumnLoopInfo);
   else
-    LI.addTopLevelLoop(ColLoop);
+    LI.addTopLevelLoop(ColumnLoopInfo);
 
   BasicBlock *ColBody =
       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
-                 "cols", B, DTU, ColLoop, LI);
-  BasicBlock *ColLatch = ColBody->getSingleSuccessor();
+                 "cols", B, DTU, ColumnLoopInfo, LI);
+  ColumnLoop.Latch = ColBody->getSingleSuccessor();
   BasicBlock *RowBody =
-      CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
-                 "rows", B, DTU, RowLoop, LI);
-  RowLoopLatch = RowBody->getSingleSuccessor();
+      CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
+                 B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
+  RowLoop.Latch = RowBody->getSingleSuccessor();
 
   BasicBlock *InnerBody =
-      CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
-                 B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
-  InnerLoopLatch = InnerBody->getSingleSuccessor();
-  ColumnLoopHeader = ColBody->getSinglePredecessor();
-  RowLoopHeader = RowBody->getSinglePredecessor();
-  InnerLoopHeader = InnerBody->getSinglePredecessor();
-  CurrentRow = &*RowLoopHeader->begin();
-  CurrentCol = &*ColumnLoopHeader->begin();
-  CurrentK = &*InnerLoopHeader->begin();
+      CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
+                 B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
+  KLoop.Latch = InnerBody->getSingleSuccessor();
+  ColumnLoop.Header = ColBody->getSinglePredecessor();
+  RowLoop.Header = RowBody->getSinglePredecessor();
+  KLoop.Header = InnerBody->getSinglePredecessor();
+  RowLoop.Index = &*RowLoop.Header->begin();
+  ColumnLoop.Index = &*ColumnLoop.Header->begin();
+  KLoop.Index = &*KLoop.Header->begin();
 
   return InnerBody;
 }


        


More information about the llvm-commits mailing list