[llvm] 0cc2d23 - [Matrix] Hoist load/store generation logic, add helpers for tiled access.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 19 12:28:43 PDT 2020
Author: Florian Hahn
Date: 2020-03-19T19:28:21Z
New Revision: 0cc2d2375165b553ea1774a25b6b2b1c931dca67
URL: https://github.com/llvm/llvm-project/commit/0cc2d2375165b553ea1774a25b6b2b1c931dca67
DIFF: https://github.com/llvm/llvm-project/commit/0cc2d2375165b553ea1774a25b6b2b1c931dca67.diff
LOG: [Matrix] Hoist load/store generation logic, add helpers for tiled access.
This patch slightly generalizes the code to emit loads and stores of a
matrix and adds helpers to load/store a tile of a larger matrix.
This will be used in a follow-up patch introducing initial tiling.
Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke
Reviewed By: anemet
Differential Revision: https://reviews.llvm.org/D75564
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 5efd3ffc2680..27ddb28aaa46 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,8 +181,8 @@ class LowerMatrixIntrinsics {
void setColumn(unsigned i, Value *V) { Columns[i] = V; }
- size_t getNumColumns() const { return Columns.size(); }
- size_t getNumRows() const {
+ unsigned getNumColumns() const { return Columns.size(); }
+ unsigned getNumRows() const {
assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
return cast<VectorType>(Columns[0]->getType())->getNumElements();
}
@@ -634,10 +634,11 @@ class LowerMatrixIntrinsics {
return true;
}
- void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Inst->getType());
+ /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+ /// columns.
+ ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
+ ShapeInfo Shape, IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
ColumnMatrixTy Result;
// Distance between start of one column and the start of the next
@@ -648,10 +649,41 @@ class LowerMatrixIntrinsics {
Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
Result.addColumn(Column);
}
+ return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
+ Result.getNumColumns());
+ }
+
+ /// Loads a sub-matrix with shape \p ResultShape from the matrix with shape
+ /// \p MatrixShape at \p MatrixPtr, starting at row \p I and column \p J.
+ ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
+ unsigned J, ShapeInfo ResultShape, Type *EltTy,
+ IRBuilder<> &Builder) {
+
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(Builder.getInt32(J),
+ Builder.getInt32(MatrixShape.NumRows)),
+ Builder.getInt32(I));
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Type *TileTy =
+ VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ return loadMatrix(TileTy, TilePtr, Builder.getInt32(ResultShape.NumRows),
+ ResultShape, Builder);
+ }
+ /// Lower a load instruction with shape information.
+ void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
+ ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
finalizeLowering(Inst,
- Result.addNumLoads(getNumOps(Result.getColumnTy()) *
- Result.getNumColumns()),
+ loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
Builder);
}
@@ -665,22 +697,54 @@ class LowerMatrixIntrinsics {
{Inst->getArgOperand(2), Inst->getArgOperand(3)});
}
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Matrix->getType());
+ /// Stores a sub-matrix \p StoreVal into the matrix with shape \p MatrixShape
+ /// at \p MatrixPtr, starting at row \p I and column \p J.
+ void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
+ ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
+ IRBuilder<> &Builder) {
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(Builder.getInt32(J),
+ Builder.getInt32(MatrixShape.NumRows)),
+ Builder.getInt32(I));
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
+ StoreVal.getNumColumns());
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ storeMatrix(TileTy, StoreVal, TilePtr,
+ Builder.getInt32(StoreVal.getNumRows()), Builder);
+ }
+
+ /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
+ /// columns.
+ ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
+ Value *Stride, IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- auto LM = getMatrix(Matrix, Shape, Builder);
- for (auto C : enumerate(LM.columns())) {
- Value *GEP =
- computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
- Shape.NumRows, VType->getElementType(), Builder);
+ for (auto C : enumerate(StoreVal.columns())) {
+ Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
+ Stride, StoreVal.getNumRows(),
+ VType->getElementType(), Builder);
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
}
- Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
- getNumOps(LM.getColumnTy()) * LM.getNumColumns());
+ return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
+ StoreVal.getNumColumns());
+ }
- ToRemove.push_back(Inst);
+ /// Lower a store instruction with shape information.
+ void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
+ ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ auto StoreVal = getMatrix(Matrix, Shape, Builder);
+ finalizeLowering(
+ Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
+ Builder);
}
/// Lowers llvm.matrix.columnwise.store.
More information about the llvm-commits
mailing list