[Mlir-commits] [mlir] 63b683a - [mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors

Nicolas Vasilache llvmlistbot at llvm.org
Mon Mar 9 10:35:29 PDT 2020


Author: Nicolas Vasilache
Date: 2020-03-09T13:34:03-04:00
New Revision: 63b683a8168fa03bd8dfa7567c73ad507104f666

URL: https://github.com/llvm/llvm-project/commit/63b683a8168fa03bd8dfa7567c73ad507104f666
DIFF: https://github.com/llvm/llvm-project/commit/63b683a8168fa03bd8dfa7567c73ad507104f666.diff

LOG: [mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors

Summary: This op mirrors its llvm.intr.matrix.multiply counterpart and allows lowering and type conversions to happen progressively.

Differential Revision: https://reviews.llvm.org/D75775

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Target/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1a50a810a265..8c95d13a2922 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -15,6 +15,12 @@ class LLVMTypeConverter;
 class ModuleOp;
 template <typename T> class OpPassBase;
 
+/// Collect a set of patterns to convert vector.matrix_multiply ops to LLVM
+/// matrix intrinsics. To lower to assembly, the LLVM flag
+/// -lower-matrix-intrinsics will be needed when invoking LLVM.
+void populateVectorToLLVMMatrixConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 39d056d93999..1fb5bd92b7f8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -836,12 +836,12 @@ def LLVM_MatrixMultiplyOp
     : LLVM_OneResultOp<"intr.matrix.multiply">,
       Arguments<(
         ins LLVM_Type:$lhs, LLVM_Type:$rhs,
-            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
+            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> {
   string llvmBuilder = [{
     llvm::MatrixBuilder<decltype(builder)> mb(builder);
     $res = mb.CreateMatrixMultiply(
       $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
-      $rhs_rows.getZExtValue());
+      $rhs_columns.getZExtValue());
   }];
   let assemblyFormat = "$lhs `,` $rhs attr-dict "
     "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";

diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 67a880d2e5d6..331a43429039 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1336,4 +1336,65 @@ def Vector_PrintOp :
   let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
+//===----------------------------------------------------------------------===//
+// Ops used for supporting progressive lowering and conversion type changes.
+//===----------------------------------------------------------------------===//
+
+/// Vector dialect matrix multiplication op that operates on flattened 1-D
+/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
+/// This may seem redundant with vector.contract but it serves the purposes of
+/// more progressive lowering and localized type conversion on the path:
+///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
+def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
+        PredOpTrait<"lhs operand and result have same element type",
+                    TCresVTEtIsSameAsOpBase<0, 0>>,
+        PredOpTrait<"rhs operand and result have same element type",
+                    TCresVTEtIsSameAsOpBase<0, 1>>]>,
+      Arguments<(
+        // TODO(ntv, fhahn): tighten vector element types that make sense.
+        ins VectorOfRankAndType<[1],
+              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
+            VectorOfRankAndType<[1],
+              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
+            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
+      Results<(
+        outs VectorOfRankAndType<[1],
+               [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
+{
+  let summary = "Vector matrix multiplication op that operates on flattened 1-D"
+    " MLIR vectors";
+  let description = [{
+    This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
+    purposes of more progressive lowering and localized type conversion.
+
+    The `vector.matrix_multiply` op treats `lhs` as matrix with <lhs_rows> rows
+    and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
+    <rhs_columns> columns and multiplies them. The <lhs_rows> x <rhs_columns>
+    result matrix is returned embedded in the result vector.
+
+    Example:
+
+    ```
+      %C = vector.matrix_multiply %A, %B
+        { lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32 } :
+        (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
+    ```
+  }];
+  let builders = [
+   OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
+             "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns",
+   [{
+     result.addOperands({lhs, rhs});
+     result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
+     result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
+     result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
+     result.addTypes(VectorType::get(lhsRows * rhsColumns,
+       lhs.getType().cast<VectorType>().getElementType()));
+   }]>,
+  ];
+  let verifier = ?;
+  let assemblyFormat = "$lhs `,` $rhs attr-dict "
+    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+}
+
 #endif // VECTOR_OPS

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a075848a9ac7..d2167c52a2d2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -275,6 +275,28 @@ class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+/// Conversion pattern for a vector.matrix_multiply.
+/// This is lowered directly to the proper llvm.intr.matrix.multiply; the
+/// pattern fails to match if the result type cannot be converted to LLVM.
+class VectorMatmulOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorMatmulOpConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto matmulOp = cast<vector::MatmulOp>(op);
+    auto adaptor = vector::MatmulOpOperandAdaptor(operands);
+    // convertType returns a null Type on failure; bail out rather than
+    // building an llvm.intr.matrix.multiply with a null result type.
+    Type llvmResType = typeConverter.convertType(matmulOp.res().getType());
+    if (!llvmResType)
+      return matchFailure();
+    // Operands are read through the adaptor over the rewriter-provided
+    // `operands`; the shape attributes are forwarded unchanged.
+    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
+        op, llvmResType, adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
+        matmulOp.lhs_columns(), matmulOp.rhs_columns());
+    return matchSuccess();
+  }
+};
+
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorPrintOpConversion>(ctx, converter);
 }
 
+void mlir::populateVectorToLLVMMatrixConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<VectorMatmulOpConversion>(
+      converter.getDialect()->getContext(), converter);
+}
+
 namespace {
 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
   void runOnModule() override;
@@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   OwningRewritePatternList patterns;
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5159031339aa..f70fb0cac6da 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -701,3 +701,15 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
 //      CHECK: llvm.return %[[V]] : !llvm.i64
 
+
+// %A: 4x16 matrix (64 elems), %B: 16x3 (48 elems), result: 4x3 (12 elems).
+func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
+  %C = vector.matrix_multiply %A, %B
+    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
+    (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
+  return %C: vector<12xf64>
+}
+// CHECK-LABEL: llvm.func @matrix_ops
+//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
+//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
+//  CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">

diff  --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index 5d6602d0ff76..f0f17966e0c3 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -136,7 +136,7 @@ llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>"
                              %ptr: !llvm<"float*">, %stride: !llvm.i32) {
   // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
   %C = llvm.intr.matrix.multiply %A, %B
-    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
+    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} :
     (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
   // CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16)
   %D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} :


        


More information about the Mlir-commits mailing list