[Mlir-commits] [mlir] bbf3ef8 - [mlir][Vector]Lower vector.contract to llvm.intr.matrix_multiply

Nicolas Vasilache llvmlistbot at llvm.org
Fri Mar 13 13:38:06 PDT 2020


Author: Nicolas Vasilache
Date: 2020-03-13T16:33:23-04:00
New Revision: bbf3ef85411660e2cebef8d1b633c35d9b179948

URL: https://github.com/llvm/llvm-project/commit/bbf3ef85411660e2cebef8d1b633c35d9b179948
DIFF: https://github.com/llvm/llvm-project/commit/bbf3ef85411660e2cebef8d1b633c35d9b179948.diff

LOG: [mlir][Vector]Lower vector.contract to llvm.intr.matrix_multiply

Summary:
This revision adds lowering of vector.contract to llvm.intr.matrix.multiply.
Note that there is currently a mismatch between the MLIR vector dialect, which
expects row-major layout, and the LLVM matrix intrinsics, which expect
column-major layout.

As a consequence, we currently only match a vector.contract with indexing maps
that express column-major matrix multiplication.
Other cases would require additional transposes, and it is better to wait for
the LLVM intrinsics to provide a per-operation attribute that specifies which
layout is expected.

A separate integration test, not submitted to MLIR core, has independently
verified that correct execution occurs on a 2x2x2 matrix multiplication.

Differential Revision: https://reviews.llvm.org/D76014

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 782a60c4e024..d54791a65410 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -17,11 +17,33 @@
 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
+
+inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
+  AffineExpr m, n, k;
+  bindDims(indexingMaps.getContext(), m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
+  auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext());
+  return indexingMaps == maps;
+}
+
+inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
+  AffineExpr m, n, k;
+  bindDims(indexingMaps.getContext(), m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}));
+  auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext());
+  return indexingMaps == maps;
+}
+
 /// Attribute name for the AffineArrayAttr which encodes the relationship
 /// between a structured op iterators' and its operands.
 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 828a964afa06..d295870d09ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1138,6 +1138,7 @@ void LowerVectorToLLVMPass::runOnModule() {
   OwningRewritePatternList patterns;
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(converter, patterns);
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index bc39b1aab0ff..d751b46059b4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
@@ -144,17 +145,10 @@ static bool hasMultiplyAddBody(linalg::GenericOp op) {
 
 // TODO(ntv) should be Tablegen'd from a single source that generates the op
 // itself.
-static bool isMatmul(linalg::GenericOp genericOp) {
-  auto *ctx = genericOp.getContext();
-  auto m = getAffineDimExpr(0, ctx);
-  auto n = getAffineDimExpr(1, ctx);
-  auto k = getAffineDimExpr(2, ctx);
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx);
+static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
   return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
-         genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
+         isRowMajorMatmul(genericOp.indexing_maps()) &&
+         hasMultiplyAddBody(genericOp);
 }
 
 // TODO(ntv, ataei): This is in fact much more general than just vectorization
@@ -172,7 +166,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
     return success();
 
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp || !isMatmul(genericOp))
+  if (!genericOp || !::isRowMajorMatmul(genericOp))
     return failure();
 
   // TODO(ntv): non-identity layout.

diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index b164df276c3b..4f6634d2be27 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -42,6 +42,13 @@ using namespace mlir;
 using llvm::dbgs;
 using mlir::functional::zipMap;
 
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<bool> lowerToLLVMMatrixIntrinsics(
+    "vector-lower-matrix-intrinsics",
+    llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
+    llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
+
 /// Given a shape with sizes greater than 0 along all dimensions,
 /// returns the distance, in number of elements, between a slice in a dimension
 /// and the next slice in the same dimension.
@@ -935,6 +942,39 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     if (llvm::size(op.masks()) != 0)
       return matchFailure();
 
+    // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
+    // a new pattern.
+    // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
+    // intrinsics, use that.
+    if (lowerToLLVMMatrixIntrinsics &&
+        isColumnMajorMatmul(op.indexing_maps())) {
+      VectorType lhsType = op.getLhsType();
+      VectorType rhsType = op.getRhsType();
+      Type flattenedLHSType =
+          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+      Type flattenedRHSType =
+          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+      auto lhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedLHSType, op.lhs());
+      auto rhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedRHSType, op.rhs());
+
+      unsigned lhsRows = op.getLhsType().getShape()[0];
+      unsigned lhsColumns = op.getLhsType().getShape()[1];
+      unsigned rhsColumns = op.getRhsType().getShape()[1];
+      Value mul = rewriter.create<vector::MatmulOp>(
+          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
+                                                 op.acc().getType(), mul);
+      Type elementType = op.getLhsType().getElementType();
+      assert(elementType.isIntOrFloat());
+      if (elementType.isa<IntegerType>())
+        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+      else
+        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+      return matchSuccess();
+    }
+
     // Find first batch dimension in LHS/RHS, and lower when found.
     std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
     if (!batchDimMap.empty()) {

diff  --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 8d6b0d8d9ccd..2440c7b4a566 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics | FileCheck %s --check-prefix=MATRIX
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -333,3 +334,47 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
   // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
+
+// MATRIX-LABEL: func @column_major_matmul
+// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      MATRIX:  %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32>
+//      MATRIX:  %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32>
+//      MATRIX:  %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+//      MATRIX:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32>
+//      MATRIX:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32>
+//      MATRIX:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32>
+//      MATRIX:  %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32>
+//      MATRIX:  %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32>
+//      MATRIX:  %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
+//      MATRIX:  %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32>
+//      MATRIX:  %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
+//      MATRIX:  %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32>
+//      MATRIX:  %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
+//      MATRIX:  %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32>
+//      MATRIX:  %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
+//      MATRIX:  %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32>
+//      MATRIX:  %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
+//      MATRIX:  %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32>
+//      MATRIX:  %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32>
+#column_major_matmat_accesses = [
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (j, i)>
+]
+#column_major_matmat_trait = {
+  indexing_maps = #column_major_matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+func @column_major_matmul(%arg0: vector<4x3xf32>,
+                          %arg1: vector<2x4xf32>,
+                          %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
+  %0 = vector.contract #column_major_matmat_trait %arg0, %arg1, %arg2
+    : vector<4x3xf32>, vector<2x4xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}


        


More information about the Mlir-commits mailing list