[Mlir-commits] [mlir] 63b683a - [mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Mar 9 10:35:29 PDT 2020
Author: Nicolas Vasilache
Date: 2020-03-09T13:34:03-04:00
New Revision: 63b683a8168fa03bd8dfa7567c73ad507104f666
URL: https://github.com/llvm/llvm-project/commit/63b683a8168fa03bd8dfa7567c73ad507104f666
DIFF: https://github.com/llvm/llvm-project/commit/63b683a8168fa03bd8dfa7567c73ad507104f666.diff
LOG: [mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors
Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion.
Differential Revision: https://reviews.llvm.org/D75775
Added:
Modified:
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Target/llvmir-intrinsics.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1a50a810a265..8c95d13a2922 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -15,6 +15,12 @@ class LLVMTypeConverter;
class ModuleOp;
template <typename T> class OpPassBase;
+/// Collect a set of patterns to convert Vector dialect matrix operations
+/// (vector.matrix_multiply) to LLVM matrix intrinsics. To lower to assembly,
+/// the LLVM flag -lower-matrix-intrinsics will be needed when invoking LLVM.
+void populateVectorToLLVMMatrixConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 39d056d93999..1fb5bd92b7f8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -836,12 +836,12 @@ def LLVM_MatrixMultiplyOp
: LLVM_OneResultOp<"intr.matrix.multiply">,
Arguments<(
ins LLVM_Type:$lhs, LLVM_Type:$rhs,
- I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
+ I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> {
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
$res = mb.CreateMatrixMultiply(
$lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
- $rhs_rows.getZExtValue());
+ $rhs_columns.getZExtValue());
}];
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 67a880d2e5d6..331a43429039 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1336,4 +1336,65 @@ def Vector_PrintOp :
let assemblyFormat = "$source attr-dict `:` type($source)";
}
+//===----------------------------------------------------------------------===//
+// Ops used for supporting progressive lowering and conversion type changes.
+//===----------------------------------------------------------------------===//
+
+/// Vector dialect matrix multiplication op that operates on flattened 1-D
+/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
+/// This may seem redundant with vector.contract but it serves the purposes of
+/// more progressive lowering and localized type conversion on the path:
+/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
+def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
+ PredOpTrait<"lhs operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ PredOpTrait<"rhs operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 1>>]>,
+ Arguments<(
+ // TODO(ntv, fhahn): tighten vector element types that make sense.
+ ins VectorOfRankAndType<[1],
+ [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
+ VectorOfRankAndType<[1],
+ [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
+ I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
+ Results<(
+ outs VectorOfRankAndType<[1],
+ [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
+{
+ let summary = "Vector matrix multiplication op that operates on flattened 1-D"
+ " MLIR vectors";
+ let description = [{
+ This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
+ purposes of more progressive lowering and localized type conversion.
+
+ The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
+ and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
+ <rhs_columns> columns and multiplies them. The result matrix, of size
+ <lhs_rows> x <rhs_columns>, is returned embedded in the result vector.
+
+ Example:
+
+ ```
+ %C = vector.matrix_multiply %A, %B
+ { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
+ (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
+ ```
+ }];
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
+ "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns",
+ [{
+ result.addOperands({lhs, rhs});
+ result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
+ result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
+ result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
+ result.addTypes(VectorType::get(lhsRows * rhsColumns,
+ lhs.getType().cast<VectorType>().getElementType()));
+ }]>,
+ ];
+ let verifier = ?;
+ let assemblyFormat = "$lhs `,` $rhs attr-dict "
+ "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+}
+
#endif // VECTOR_OPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a075848a9ac7..d2167c52a2d2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -275,6 +275,28 @@ class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
}
};
+/// Lowers vector.matrix_multiply on flattened 1-D vectors by replacing it
+/// one-to-one with llvm.intr.matrix.multiply, forwarding the shape attributes.
+class VectorMatmulOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorMatmulOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ vector::MatmulOp source = cast<vector::MatmulOp>(op);
+ vector::MatmulOpOperandAdaptor converted(operands);
+ auto resultType = typeConverter.convertType(source.res().getType());
+ rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
+ op, resultType, converted.lhs(), converted.rhs(), source.lhs_rows(),
+ source.lhs_columns(), source.rhs_columns());
+ return matchSuccess();
+ }
+};
+
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorPrintOpConversion>(ctx, converter);
}
+void mlir::populateVectorToLLVMMatrixConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ patterns.insert<VectorMatmulOpConversion>(
+ converter.getDialect()->getContext(), converter);
+}
+
namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
void runOnModule() override;
@@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
+ populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5159031339aa..f70fb0cac6da 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -701,3 +701,15 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
// CHECK: llvm.return %[[V]] : !llvm.i64
+
+// Flattened matmul: lhs is 4x16 (64 elts), rhs is 16x3 (48), result 4x3 (12).
+func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
+ %C = vector.matrix_multiply %A, %B
+ { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
+ (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
+ return %C: vector<12xf64>
+}
+// CHECK-LABEL: llvm.func @matrix_ops
+// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
+// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
+// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index 5d6602d0ff76..f0f17966e0c3 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -136,7 +136,7 @@ llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>"
%ptr: !llvm<"float*">, %stride: !llvm.i32) {
// CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
%C = llvm.intr.matrix.multiply %A, %B
- { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
+ { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} :
(!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
// CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16)
%D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} :
More information about the Mlir-commits
mailing list