[llvm] be86bc7 - [Matrix] Generalize ColumnMatrixTy to MatrixTy (NFC).

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 20 01:33:27 PDT 2020


Author: Florian Hahn
Date: 2020-03-20T08:32:13Z
New Revision: be86bc76f0c21c024ed15704f266eb3595088b02

URL: https://github.com/llvm/llvm-project/commit/be86bc76f0c21c024ed15704f266eb3595088b02
DIFF: https://github.com/llvm/llvm-project/commit/be86bc76f0c21c024ed15704f266eb3595088b02.diff

LOG: [Matrix] Generalize ColumnMatrixTy to MatrixTy (NFC).

This patch sets the stage for supporting both row-major and column-major
layouts for matrices. It renames ColumnMatrixTy to MatrixTy, adds
booleans indicating the underlying layout to both MatrixTy and ShapeInfo,
and generalizes the methods of MatrixTy to support both row-major and
column-major layouts.

Reviewers: Gerolf, anemet, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D76324

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3b66aa9c6e41..671da3f53993 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -165,66 +165,85 @@ class LowerMatrixIntrinsics {
     }
   };
 
-  /// Wrapper class representing a matrix as a set of column vectors.
-  /// All column vectors must have the same vector type.
-  class ColumnMatrixTy {
-    SmallVector<Value *, 16> Columns;
+  /// Wrapper class representing a matrix as a set of vectors, either in row or
+  /// column major layout. All vectors must have the same vector type.
+  class MatrixTy {
+    SmallVector<Value *, 16> Vectors;
 
     OpInfoTy OpInfo;
 
+    bool IsColumnMajor = true;
+
   public:
-    ColumnMatrixTy() : Columns() {}
-    ColumnMatrixTy(ArrayRef<Value *> Cols)
-        : Columns(Cols.begin(), Cols.end()) {}
+    MatrixTy() : Vectors() {}
+    MatrixTy(ArrayRef<Value *> Vectors)
+        : Vectors(Vectors.begin(), Vectors.end()) {}
+
+    Value *getVector(unsigned i) const { return Vectors[i]; }
+    Value *getColumn(unsigned i) const {
+      assert(isColumnMajor() && "only supported for column-major matrixes");
+      return Vectors[i];
+    }
 
-    Value *getColumn(unsigned i) const { return Columns[i]; }
+    void setColumn(unsigned i, Value *V) { Vectors[i] = V; }
 
-    void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+    Type *getElementType() { return getVectorTy()->getElementType(); }
 
-    Type *getElementType() {
-      return cast<VectorType>(Columns[0]->getType())->getElementType();
+    unsigned getNumColumns() const {
+      if (isColumnMajor())
+        return Vectors.size();
+      else {
+        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+        return cast<VectorType>(Vectors[0]->getType())->getNumElements();
+      }
     }
-
-    unsigned getNumColumns() const { return Columns.size(); }
     unsigned getNumRows() const {
-      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
-      return cast<VectorType>(Columns[0]->getType())->getNumElements();
+      if (isColumnMajor()) {
+        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+        return cast<VectorType>(Vectors[0]->getType())->getNumElements();
+      } else
+        return Vectors.size();
     }
 
-    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
+    const SmallVectorImpl<Value *> &getColumnVectors() const { return Vectors; }
 
-    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
+    SmallVectorImpl<Value *> &getColumnVectors() { return Vectors; }
 
-    void addColumn(Value *V) { Columns.push_back(V); }
+    void addColumn(Value *V) { Vectors.push_back(V); }
 
     VectorType *getColumnTy() {
-      return cast<VectorType>(Columns[0]->getType());
+      assert(isColumnMajor() && "only supported for column-major matrixes");
+      return getVectorTy();
+    }
+
+    VectorType *getVectorTy() {
+      return cast<VectorType>(Vectors[0]->getType());
     }
 
     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
-      return make_range(Columns.begin(), Columns.end());
+      return make_range(Vectors.begin(), Vectors.end());
     }
 
     /// Embed the columns of the matrix into a flat vector by concatenating
     /// them.
     Value *embedInVector(IRBuilder<> &Builder) const {
-      return Columns.size() == 1 ? Columns[0]
-                                 : concatenateVectors(Builder, Columns);
+      return Vectors.size() == 1 ? Vectors[0]
+                                 : concatenateVectors(Builder, Vectors);
     }
 
-    ColumnMatrixTy &addNumLoads(unsigned N) {
+    MatrixTy &addNumLoads(unsigned N) {
       OpInfo.NumLoads += N;
       return *this;
     }
 
     void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
 
-    ColumnMatrixTy &addNumStores(unsigned N) {
+    MatrixTy &addNumStores(unsigned N) {
       OpInfo.NumStores += N;
       return *this;
     }
 
-    ColumnMatrixTy &addNumComputeOps(unsigned N) {
+    MatrixTy &addNumComputeOps(unsigned N) {
       OpInfo.NumComputeOps += N;
       return *this;
     }
@@ -234,6 +253,8 @@ class LowerMatrixIntrinsics {
     unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
 
     const OpInfoTy &getOpInfo() const { return OpInfo; }
+
+    bool isColumnMajor() const { return IsColumnMajor; }
   };
 
   struct ShapeInfo {
@@ -274,7 +295,7 @@ class LowerMatrixIntrinsics {
   SmallVector<Instruction *, 16> ToRemove;
 
   /// Map from instructions to their produced column matrix.
-  MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;
+  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
 
 public:
   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
@@ -300,8 +321,8 @@ class LowerMatrixIntrinsics {
   /// If we lowered \p MatrixVal, just return the cache result column matrix.
   /// Otherwie split the flat vector \p MatrixVal containing a matrix with
   /// shape \p SI into column vectors.
-  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
-                           IRBuilder<> &Builder) {
+  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+                     IRBuilder<> &Builder) {
     VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
     assert(VType && "MatrixVal must be a vector type");
     assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
@@ -313,7 +334,7 @@ class LowerMatrixIntrinsics {
     // vector and split it later.
     auto Found = Inst2ColumnMatrix.find(MatrixVal);
     if (Found != Inst2ColumnMatrix.end()) {
-      ColumnMatrixTy &M = Found->second;
+      MatrixTy &M = Found->second;
       // Return the found matrix, if its shape matches the requested shape
       // information
       if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
@@ -640,11 +661,11 @@ class LowerMatrixIntrinsics {
 
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// columns.
-  ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
-                            ShapeInfo Shape, IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape,
+                      IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
-    ColumnMatrixTy Result;
+    MatrixTy Result;
     // Distance between start of one column and the start of the next
     for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
       Value *GEP =
@@ -659,9 +680,9 @@ class LowerMatrixIntrinsics {
 
   /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
   /// starting at \p MatrixPtr[I][J].
-  ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
-                            unsigned J, ShapeInfo ResultShape, Type *EltTy,
-                            IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
+                      unsigned J, ShapeInfo ResultShape, Type *EltTy,
+                      IRBuilder<> &Builder) {
 
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(Builder.getInt32(J),
@@ -703,7 +724,7 @@ class LowerMatrixIntrinsics {
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
   /// MatrixPtr[I][J].
-  void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
+  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                    ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
                    IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
@@ -727,8 +748,8 @@ class LowerMatrixIntrinsics {
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
   /// columns.
-  ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
-                             Value *Stride, IRBuilder<> &Builder) {
+  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
+                       IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     for (auto C : enumerate(StoreVal.columns())) {
@@ -737,8 +758,8 @@ class LowerMatrixIntrinsics {
                                      VType->getElementType(), Builder);
       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
     }
-    return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
-                                         StoreVal.getNumColumns());
+    return MatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
+                                   StoreVal.getNumColumns());
   }
 
   /// Lower a store instruction with shape information.
@@ -764,7 +785,7 @@ class LowerMatrixIntrinsics {
 
   /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
   /// the matrix \p LM represented as a vector of column vectors.
-  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
+  Value *extractVector(const MatrixTy &LM, unsigned I, unsigned J,
                        unsigned NumElts, IRBuilder<> &Builder) {
     Value *Col = LM.getColumn(J);
     Value *Undef = UndefValue::get(Col->getType());
@@ -836,7 +857,7 @@ class LowerMatrixIntrinsics {
   /// cached value when they are lowered. For other users, \p Matrix is
   /// flattened and the uses are updated to use it. Also marks \p Inst for
   /// deletion.
-  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
+  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                         IRBuilder<> &Builder) {
     Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
 
@@ -854,9 +875,8 @@ class LowerMatrixIntrinsics {
 
   /// Compute Res += A * B for tile-sized matrices with left-associating
   /// addition.
-  void emitChainedMatrixMultiply(ColumnMatrixTy &Result,
-                                 const ColumnMatrixTy &A,
-                                 const ColumnMatrixTy &B, bool AllowContraction,
+  void emitChainedMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
+                                 const MatrixTy &B, bool AllowContraction,
                                  IRBuilder<> &Builder, bool isTiled) {
     const unsigned VF = std::max<unsigned>(
         TTI.getRegisterBitWidth(true) /
@@ -902,17 +922,15 @@ class LowerMatrixIntrinsics {
     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
 
-    const ColumnMatrixTy &Lhs =
-        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
-    const ColumnMatrixTy &Rhs =
-        getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
 
     const unsigned R = LShape.NumRows;
     const unsigned C = RShape.NumColumns;
     assert(LShape.NumColumns == RShape.NumRows);
 
     // Initialize the output
-    ColumnMatrixTy Result;
+    MatrixTy Result;
     for (unsigned J = 0; J < C; ++J)
       Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
 
@@ -924,12 +942,12 @@ class LowerMatrixIntrinsics {
 
   /// Lowers llvm.matrix.transpose.
   void LowerTranspose(CallInst *Inst) {
-    ColumnMatrixTy Result;
+    MatrixTy Result;
     IRBuilder<> Builder(Inst);
     Value *InputVal = Inst->getArgOperand(0);
     VectorType *VectorTy = cast<VectorType>(InputVal->getType());
     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
-    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
 
     for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
       // Build a single column vector for this row. First initialize it.
@@ -989,11 +1007,11 @@ class LowerMatrixIntrinsics {
     IRBuilder<> Builder(Inst);
     ShapeInfo &Shape = I->second;
 
-    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
-    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
+    MatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
+    MatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
 
     // Add each column and store the result back into the opmapping
-    ColumnMatrixTy Result;
+    MatrixTy Result;
     auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
       switch (Inst->getOpcode()) {
       case Instruction::Add:
@@ -1035,7 +1053,7 @@ class LowerMatrixIntrinsics {
 
     /// Mapping from instructions to column matrixes. It is used to identify
     /// matrix instructions.
-    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
+    const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix;
 
     /// Mapping from values to the leaves of all expressions that the value is
     /// part of.
@@ -1052,7 +1070,7 @@ class LowerMatrixIntrinsics {
     SmallPtrSet<Value *, 8> ReusedExprs;
 
     ExprLinearizer(const DataLayout &DL,
-                   const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
+                   const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix,
                    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                    const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                    Value *Leaf)
@@ -1296,12 +1314,12 @@ class LowerMatrixIntrinsics {
   ///    that multiple leaves can share sub-expressions. Shared subexpressions
   ///    are explicitly marked as shared().
   struct RemarkGenerator {
-    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
+    const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix;
     OptimizationRemarkEmitter &ORE;
     Function &Func;
     const DataLayout &DL;
 
-    RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
+    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix,
                     OptimizationRemarkEmitter &ORE, Function &Func)
         : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func),
           DL(Func.getParent()->getDataLayout()) {}


        


More information about the llvm-commits mailing list