[llvm] 0cc2d23 - [Matrix] Hoist load/store generation logic, add helpers for tiled access.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 19 12:28:43 PDT 2020


Author: Florian Hahn
Date: 2020-03-19T19:28:21Z
New Revision: 0cc2d2375165b553ea1774a25b6b2b1c931dca67

URL: https://github.com/llvm/llvm-project/commit/0cc2d2375165b553ea1774a25b6b2b1c931dca67
DIFF: https://github.com/llvm/llvm-project/commit/0cc2d2375165b553ea1774a25b6b2b1c931dca67.diff

LOG: [Matrix] Hoist load/store generation logic, add helpers for tiled access.

This patch slightly generalizes the code to emit loads and stores of a
matrix and adds helpers to load/store a tile of a larger matrix.

This will be used in a follow-up patch introducing initial tiling.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D75564

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 5efd3ffc2680..27ddb28aaa46 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,8 +181,8 @@ class LowerMatrixIntrinsics {
 
     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
 
-    size_t getNumColumns() const { return Columns.size(); }
-    size_t getNumRows() const {
+    unsigned getNumColumns() const { return Columns.size(); }
+    unsigned getNumRows() const {
       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
       return cast<VectorType>(Columns[0]->getType())->getNumElements();
     }
@@ -634,10 +634,11 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
-                 ShapeInfo Shape) {
-    IRBuilder<> Builder(Inst);
-    auto VType = cast<VectorType>(Inst->getType());
+  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+  /// columns.
+  ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
+                            ShapeInfo Shape, IRBuilder<> &Builder) {
+    auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     ColumnMatrixTy Result;
     // Distance between start of one column and the start of the next
@@ -648,10 +649,41 @@ class LowerMatrixIntrinsics {
       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
       Result.addColumn(Column);
     }
+    return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
+                              Result.getNumColumns());
+  }
+
+  /// Loads a sub-matrix with shape \p ResultShape from the \p MatrixShape
+  /// matrix at \p MatrixPtr, starting at element [\p I][\p J] (column-major).
+  ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
+                            unsigned J, ShapeInfo ResultShape, Type *EltTy,
+                            IRBuilder<> &Builder) {
+    // Element [I][J] of the column-major matrix is at offset J * NumRows + I.
+    Value *Offset = Builder.CreateAdd(
+        Builder.CreateMul(Builder.getInt32(J),
+                          Builder.getInt32(MatrixShape.NumRows)),
+        Builder.getInt32(I));
+
+    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+    Value *EltPtr =
+        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+    Type *TileTy =
+        VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
+    Type *TilePtrTy = PointerType::get(TileTy, AS);
+    Value *TilePtr =
+        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+    // Tile columns are MatrixShape.NumRows elements apart in memory.
+    return loadMatrix(TileTy, TilePtr, Builder.getInt32(MatrixShape.NumRows),
+                      ResultShape, Builder);
+  }
 
+  /// Lower load \p Inst of a \p Shape matrix from \p Ptr with \p Stride.
+  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
+                 ShapeInfo Shape) {
+    IRBuilder<> Builder(Inst);
     finalizeLowering(Inst,
-                     Result.addNumLoads(getNumOps(Result.getColumnTy()) *
-                                        Result.getNumColumns()),
+                     loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
                      Builder);
   }
 
@@ -665,22 +697,54 @@ class LowerMatrixIntrinsics {
               {Inst->getArgOperand(2), Inst->getArgOperand(3)});
   }
 
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  ShapeInfo Shape) {
-    IRBuilder<> Builder(Inst);
-    auto VType = cast<VectorType>(Matrix->getType());
+  /// Stores sub-matrix \p StoreVal into the \p MatrixShape matrix at
+  /// \p MatrixPtr, starting at element [\p I][\p J] (column-major).
+  void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
+                   ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
+                   IRBuilder<> &Builder) {
+    // Element [I][J] of the column-major matrix is at offset J * NumRows + I.
+    Value *Offset = Builder.CreateAdd(
+        Builder.CreateMul(Builder.getInt32(J),
+                          Builder.getInt32(MatrixShape.NumRows)),
+        Builder.getInt32(I));
+    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+    Value *EltPtr =
+        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+    Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
+                                              StoreVal.getNumColumns());
+    Type *TilePtrTy = PointerType::get(TileTy, AS);
+    Value *TilePtr =
+        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+    // Tile columns are MatrixShape.NumRows elements apart in memory.
+    storeMatrix(TileTy, StoreVal, TilePtr,
+                Builder.getInt32(MatrixShape.NumRows), Builder);
+  }
+
+  /// Store matrix \p StoreVal (vector type \p Ty) at \p Ptr, using \p Stride
+  /// between columns; returns an empty matrix carrying the number of stores.
+  ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
+                             Value *Stride, IRBuilder<> &Builder) {
+    auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
-    auto LM = getMatrix(Matrix, Shape, Builder);
-    for (auto C : enumerate(LM.columns())) {
-      Value *GEP =
-          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
-                            Shape.NumRows, VType->getElementType(), Builder);
+    for (auto C : enumerate(StoreVal.columns())) {
+      Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
+                                     Stride, StoreVal.getNumRows(),
+                                     VType->getElementType(), Builder);
       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
     }
-    Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
-        getNumOps(LM.getColumnTy()) * LM.getNumColumns());
+    return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
+                                         StoreVal.getNumColumns());
+  }
 
-    ToRemove.push_back(Inst);
+  /// Lower store \p Inst of \p Matrix with \p Shape to \p Ptr using \p Stride.
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
+                  ShapeInfo Shape) {
+    IRBuilder<> Builder(Inst);
+    auto StoreVal = getMatrix(Matrix, Shape, Builder);
+    finalizeLowering(
+        Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
+        Builder);
   }
 
   /// Lowers llvm.matrix.columnwise.store.


        


More information about the llvm-commits mailing list