[llvm] 2a70db2 - [Matrix] Add IR MatrixBuilder.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 4 03:19:06 PST 2020
Author: Florian Hahn
Date: 2020-03-04T11:14:20Z
New Revision: 2a70db245d6ea19e5614c7c22660de1a5dd8e898
URL: https://github.com/llvm/llvm-project/commit/2a70db245d6ea19e5614c7c22660de1a5dd8e898
DIFF: https://github.com/llvm/llvm-project/commit/2a70db245d6ea19e5614c7c22660de1a5dd8e898.diff
LOG: [Matrix] Add IR MatrixBuilder.
This builder provides a convenient way for targets to lower various matrix
operations to LLVM IR, making use of matrix intrinsics where available.
Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke
Reviewed By: anemet
Differential Revision: https://reviews.llvm.org/D72280
Added:
llvm/include/llvm/IR/MatrixBuilder.h
Modified:
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
new file mode 100644
index 000000000000..ab92f86ebb4e
--- /dev/null
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -0,0 +1,172 @@
+//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the MatrixBuilder class, which is used as a convenient way
+// to lower matrix operations to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_MATRIXBUILDER_H
+#define LLVM_IR_MATRIXBUILDER_H
+
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+
+namespace llvm {
+
+class Function;
+class Twine;
+class Module;
+
+/// Convenience wrapper around an IRBuilder that emits LLVM IR for matrix
+/// operations, calling the matrix intrinsics for load, store, transpose and
+/// multiply. Matrices are flat vectors indexed as Column * NumRows + Row.
+template <class IRBuilderTy> class MatrixBuilder {
+  IRBuilderTy &B;
+  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
+
+public:
+  MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
+
+  /// Create a columnwise, strided matrix load.
+  /// \p DataPtr - Start address of the matrix read
+  /// \p Rows - Number of rows in matrix (must be a constant)
+  /// \p Columns - Number of columns in matrix (must be a constant)
+  /// \p Stride - Space between columns
+  CallInst *CreateMatrixColumnwiseLoad(Value *DataPtr, unsigned Rows,
+                                       unsigned Columns, Value *Stride,
+                                       const Twine &Name = "") {
+    // The result is a flat vector of Rows * Columns pointee-typed elements.
+    PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
+    Type *EltTy = PtrTy->getElementType();
+    Type *RetType = VectorType::get(EltTy, Rows * Columns);
+
+    Value *Ops[] = {DataPtr, Stride, B.getInt32(Rows), B.getInt32(Columns)};
+    Type *OverloadedTypes[] = {RetType, PtrTy};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_columnwise_load, OverloadedTypes);
+
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
+  /// Create a columnwise, strided matrix store.
+  /// \p Matrix - Matrix to store
+  /// \p Ptr - Pointer to write back to
+  /// \p Stride - Space between columns
+  /// \p Rows - Number of rows in matrix (must be a constant)
+  /// \p Columns - Number of columns in matrix (must be a constant)
+  CallInst *CreateMatrixColumnwiseStore(Value *Matrix, Value *Ptr,
+                                        Value *Stride, unsigned Rows,
+                                        unsigned Columns,
+                                        const Twine &Name = "") {
+    Value *Ops[] = {Matrix, Ptr, Stride, B.getInt32(Rows), B.getInt32(Columns)};
+    Type *OverloadedTypes[] = {Matrix->getType(), Ptr->getType()};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_columnwise_store, OverloadedTypes);
+
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
+  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
+  /// rows and \p Columns columns.
+  CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
+                                  unsigned Columns, const Twine &Name = "") {
+    auto *OpType = cast<VectorType>(Matrix->getType());
+    Type *ReturnType =
+        VectorType::get(OpType->getElementType(), Rows * Columns);
+
+    Type *OverloadedTypes[] = {ReturnType};
+    Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
+
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
+  /// Create a llvm.matrix.multiply call, multiplying matrices \p LHS and \p
+  /// RHS. \p LHS has \p LHSRows rows and \p LHSColumns columns, \p RHS has
+  /// \p RHSColumns columns; the result has LHSRows * RHSColumns elements.
+  CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
+                                 unsigned LHSColumns, unsigned RHSColumns,
+                                 const Twine &Name = "") {
+    auto *LHSType = cast<VectorType>(LHS->getType());
+    auto *RHSType = cast<VectorType>(RHS->getType());
+
+    Type *ReturnType =
+        VectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
+
+    Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
+                    B.getInt32(RHSColumns)};
+    Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
+  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
+  /// ColumnIdx). \p NumRows is the number of rows of \p Matrix.
+  Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
+                            Value *ColumnIdx, unsigned NumRows) {
+    // Flat vector index of the element: ColumnIdx * NumRows + RowIdx.
+    Value *ColumnStart = B.CreateMul(
+        ColumnIdx, ConstantInt::get(ColumnIdx->getType(), NumRows));
+    return B.CreateInsertElement(Matrix, NewVal,
+                                 B.CreateAdd(ColumnStart, RowIdx));
+  }
+
+  /// Add matrices \p LHS and \p RHS. Support both integer and floating point
+  /// matrices; the element type of \p LHS selects FAdd or Add.
+  Value *CreateAdd(Value *LHS, Value *RHS) {
+    Type *EltTy = cast<VectorType>(LHS->getType())->getElementType();
+    return EltTy->isFloatingPointTy() ? B.CreateFAdd(LHS, RHS)
+                                      : B.CreateAdd(LHS, RHS);
+  }
+
+  /// Subtract matrices \p LHS and \p RHS. Support both integer and floating
+  /// point matrices; the element type of \p LHS selects FSub or Sub.
+  Value *CreateSub(Value *LHS, Value *RHS) {
+    Type *EltTy = cast<VectorType>(LHS->getType())->getElementType();
+    return EltTy->isFloatingPointTy() ? B.CreateFSub(LHS, RHS)
+                                      : B.CreateSub(LHS, RHS);
+  }
+
+  /// Multiply matrix \p LHS with scalar \p RHS. \p RHS is splatted to a vector
+  /// with as many elements as \p LHS, then multiplied element-wise.
+  Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
+    Value *ScalarVector =
+        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
+                            RHS, "scalar.splat");
+    if (RHS->getType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, ScalarVector);
+    return B.CreateMul(LHS, ScalarVector);
+  }
+
+  /// Extracts the element at (\p Row, \p Column) from \p Matrix, which has
+  /// \p NumRows rows. \p Name is used to name the extracted value.
+  Value *CreateExtractMatrix(Value *Matrix, Value *Row, Value *Column,
+                             unsigned NumRows, Twine const &Name = "") {
+    // Flat vector index of the element: Column * NumRows + Row.
+    Value *ColumnStart =
+        B.CreateMul(Column, ConstantInt::get(Column->getType(), NumRows));
+    // Forward Name so the caller-supplied name is not silently dropped.
+    return B.CreateExtractElement(Matrix, B.CreateAdd(ColumnStart, Row), Name);
+  }
+};
+
+} // end namespace llvm
+
+#endif // LLVM_IR_MATRIXBUILDER_H
More information about the llvm-commits
mailing list