[Mlir-commits] [mlir] ee01c7a - [mlir] [VectorOps] Add choice between dot and axpy lowering of vector.contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 2 13:21:26 PDT 2020


Author: aartbik
Date: 2020-07-02T13:21:17-07:00
New Revision: ee01c7a7406345a50176216216ca384fb60e0267

URL: https://github.com/llvm/llvm-project/commit/ee01c7a7406345a50176216216ca384fb60e0267
DIFF: https://github.com/llvm/llvm-project/commit/ee01c7a7406345a50176216216ca384fb60e0267.diff

LOG: [mlir] [VectorOps] Add choice between dot and axpy lowering of vector.contract

Default vector.contract lowering essentially yields a series of sdot/ddot
operations. However, for some layouts a series of saxpy/daxpy operations,
chained through fma, is more efficient. This CL introduces a choice between
the two lowering paths. A heuristic to pick the default path will follow.

Some preliminary AVX2 performance numbers for matrix-times-vector follow.
Here, dot performs best for 64x64 A x b, and saxpy performs best for 64x64 A^T x b.

```
------------------------------------------------------------
            A x b                          A^T x b
------------------------------------------------------------
GFLOPS    sdot (reassoc)    saxpy    sdot (reassoc)    saxpy
------------------------------------------------------------
1x1        0.6               0.9       0.6             0.9
2x2        2.5               3.2       2.4             3.5
4x4        6.4               8.4       4.9             11.8
8x8       11.7               6.1       5.0             29.6
16x16     20.7              10.8       7.3             43.3
32x32     29.3               7.9       6.4             51.8
64x64     38.9                                         79.3
128x128   32.4                                         40.7
------------------------------------------------------------
```

Reviewed By: nicolasvasilache, ftynse

Differential Revision: https://reviews.llvm.org/D83012

Added: 
    mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index def0d24adcf5..dd79b2986963 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -46,12 +46,14 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
 
 /// Enum to control the lowering of `vector.contract` operations.
 enum class VectorContractLowering {
-  /// Progressively lower to finer grained `vector.contract` and `vector.fma`.
-  FMA = 0,
+  /// Progressively lower to finer grained `vector.contract` and dot-products.
+  Dot = 0,
   /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
+  /// Lower to a series of AXPY operations chained through FMA.
+  AXPY = 3,
 };
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {
@@ -63,7 +65,7 @@ enum class VectorTransposeLowering {
 };
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
-  VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
   VectorTransposeLowering vectorTransposeLowering =
       VectorTransposeLowering::EltWise;
   VectorTransformsOptions &

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 08aa579d651b..1864d45ac552 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -135,6 +135,33 @@ class ContractionOpToOuterProductOpLowering
   vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with
+/// matvec semantics to a series of AXPY operations that are chained
+/// through FMA operations.
+///
+/// Only matches when the vectorContractLowering option is set to AXPY.
+//
+// TODO(ajcbik): this is very similar, but not quite the same as
+//               the outerproduct lowering above; merge the two?
+class ContractionOpToAXPYLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  ContractionOpToAXPYLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
+  LogicalResult match(vector::ContractionOp op) const override;
+  void rewrite(vector::ContractionOp op,
+               PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+};
+
 /// Progressive lowering of ContractionOp.
 ///
 /// One:
@@ -145,10 +172,10 @@ class ContractionOpToOuterProductOpLowering
 ///   ..
 ///   %x = combine %a %b ..
 /// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a fma/reduction op.
+/// which is replaced by a dot-product.
 ///
-/// This only kicks in when either VectorTransformsOptions is set to FMA or when
-/// other contraction patterns fail.
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index d8cd5a7fe0e8..b841580433f9 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1560,15 +1560,17 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
   if (llvm::size(op.masks()) != 0)
     return failure();
 
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Matmul)
+    return failure();
+
   auto iteratorTypes = op.iterator_types().getValue();
   if (!isParallelIterator(iteratorTypes[0]) ||
       !isParallelIterator(iteratorTypes[1]) ||
       !isReductionIterator(iteratorTypes[2]))
     return failure();
 
-  if (vectorTransformsOptions.vectorContractLowering !=
-          vector::VectorContractLowering::Matmul ||
-      !isRowMajorMatmul(op.indexing_maps()))
+  if (!isRowMajorMatmul(op.indexing_maps()))
     return failure();
 
   return success();
@@ -1578,9 +1580,9 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
                                               PatternRewriter &rewriter) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
-  unsigned lhsRows = op.getLhsType().getShape()[0];
-  unsigned lhsColumns = op.getLhsType().getShape()[1];
-  unsigned rhsColumns = op.getRhsType().getShape()[1];
+  int64_t lhsRows = lhsType.getDimSize(0);
+  int64_t lhsColumns = lhsType.getDimSize(1);
+  int64_t rhsColumns = rhsType.getDimSize(1);
 
   Type flattenedLHSType =
       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
@@ -1657,7 +1659,7 @@ ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
 void ContractionOpToOuterProductOpLowering::rewrite(
     vector::ContractionOp op, PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  unsigned reductionSize = 0;
+  int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
@@ -1673,41 +1675,41 @@ void ContractionOpToOuterProductOpLowering::rewrite(
   // First batch of cases, no need to output permute.
   if (maps == infer({{m, k}, {k, n}, {m, n}})) {
     // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getShape()[1];
+    reductionSize = lhsType.getDimSize(1);
     lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
   } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    reductionSize = lhsType.getShape()[1];
+    reductionSize = lhsType.getDimSize(1);
     lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
     rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
   } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
     // No need to permute anything.
-    reductionSize = lhsType.getShape()[0];
+    reductionSize = lhsType.getDimSize(0);
   } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
     // Just permute the rhs.
-    reductionSize = lhsType.getShape()[0];
+    reductionSize = lhsType.getDimSize(0);
     rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
   }
   // Second batch of cases, reshuffle to avoid output permute.
   else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
     // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getShape()[1];
+    reductionSize = lhsType.getDimSize(1);
     Value tmp = rhs;
     rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
     lhs = tmp;
   } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    reductionSize = lhsType.getShape()[1];
+    reductionSize = lhsType.getDimSize(1);
     Value tmp = rhs;
     rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
     lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
   } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
     // No need to permute anything, but still swap lhs and rhs.
-    reductionSize = lhsType.getShape()[0];
+    reductionSize = lhsType.getDimSize(0);
     std::swap(lhs, rhs);
   } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
     // Just permute the rhs.
-    reductionSize = lhsType.getShape()[0];
+    reductionSize = lhsType.getDimSize(0);
     Value tmp = lhs;
     lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
     rhs = tmp;
@@ -1723,6 +1725,88 @@ void ContractionOpToOuterProductOpLowering::rewrite(
   rewriter.replaceOp(op, res);
 }
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with
+/// matvec semantics to a series of AXPY operations that are chained
+/// through FMA operations.
+///
+/// Only matches when the vectorContractLowering option is set to AXPY.
+//
+// TODO(ajcbik): this is very similar, but not quite the same as
+//               the outerproduct lowering above; merge the two?
+LogicalResult
+ContractionOpToAXPYLowering::match(vector::ContractionOp op) const {
+  // TODO(ajcbik): implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::AXPY)
+    return failure();
+
+  auto iteratorTypes = op.iterator_types().getValue();
+  if (iteratorTypes.size() != 2 || !isParallelIterator(iteratorTypes[0]) ||
+      !isReductionIterator(iteratorTypes[1]))
+    return failure();
+
+  // Only the four matrix-vector layouts below map onto a series of
+  // AXPY operations (1-D dot products are rejected by the size check above).
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n;
+  bindDims(op.getContext(), m, n);
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
+      maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
+      maps != infer({{n}, {m, n}, {m}}) && // vec-mat
+      maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
+    return failure();
+  return success();
+}
+
+void ContractionOpToAXPYLowering::rewrite(vector::ContractionOp op,
+                                          PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  VectorType lhsType = op.getLhsType();
+  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n;
+  bindDims(op.getContext(), m, n);
+  SmallVector<int64_t, 2> perm{1, 0};
+  // Normalize so lhs is a [reduction][result] matrix and rhs the vector.
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  int64_t reductionSize = 0;
+  if (maps == infer({{m, n}, {n}, {m}})) {
+    // Case mat-vec: transpose.
+    reductionSize = lhsType.getDimSize(1);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{n, m}, {n}, {m}})) {
+    // Case mat-trans-vec: ready to go.
+    reductionSize = lhsType.getDimSize(0);
+  } else if (maps == infer({{n}, {m, n}, {m}})) {
+    // Case vec-mat: swap and transpose; lhsType is still the n-vector.
+    reductionSize = lhsType.getDimSize(0);
+    std::swap(lhs, rhs);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{n}, {n, m}, {m}})) {
+    // Case vec-mat-trans: swap only; lhsType is still the n-vector.
+    reductionSize = lhsType.getDimSize(0);
+    std::swap(lhs, rhs);
+  }
+  assert(reductionSize > 0);
+
+  // Accumulate res = fma(lhs[k], broadcast(rhs[k]), res) for each k.
+  Type resType = op.getResultType();
+  for (int64_t k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+    Value s = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+    Value b = rewriter.create<vector::BroadcastOp>(loc, resType, s);
+    res = rewriter.create<vector::FMAOp>(loc, a, b, res);
+  }
+  rewriter.replaceOp(op, res);
+}
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1732,11 +1816,14 @@ void ContractionOpToOuterProductOpLowering::rewrite(
 ///   ..
 ///   %x = combine %a %b ..
 /// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product/reduction pair.
+/// which is replaced by a dot-product.
 ///
-/// TODO(ajcbik): break down into transpose/reshape/cast ops
-///               when they become available to avoid code dup
-/// TODO(ajcbik): investigate lowering order impact on performance
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
+//
+// TODO(ajcbik): break down into transpose/reshape/cast ops
+//               when they become available to avoid code dup
+// TODO(ajcbik): investigate lowering order impact on performance
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
@@ -1758,6 +1845,9 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
   if (succeeded(pat2.match(op)))
     return failure();
+  ContractionOpToAXPYLowering pat3(vectorTransformsOptions, ctx);
+  if (succeeded(pat3.match(op)))
+    return failure();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
@@ -1943,6 +2033,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
   patterns.insert<TransposeOpLowering,
                   ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering>(parameters, context);
+                  ContractionOpToOuterProductOpLowering,
+                  ContractionOpToAXPYLowering>(parameters, context);
   // clang-format on
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
new file mode 100644
index 000000000000..81c70983cded
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -0,0 +1,163 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#mattransvec_accesses = [
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#mattransvec_trait = {
+  indexing_maps = #mattransvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#vecmat_accesses = [
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i)>
+]
+#vecmat_trait = {
+  indexing_maps = #vecmat_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#vecmattrans_accesses = [
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> (i)>
+]
+#vecmattrans_trait = {
+  indexing_maps = #vecmattrans_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matvec2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
+// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
+// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                                                %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @mattransvec2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
+// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                                                     %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @vecmat2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
+// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
+// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                                                %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @vecmattrans2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
+// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                                                     %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 22585fde4ff7..c6cf45e824d7 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,9 +59,14 @@ struct TestVectorContractionConversion
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
+  Option<bool> lowerToAXPY{*this, "vector-axpy",
+                           llvm::cl::desc("Lower vector.contract to AXPY"),
+                           llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
+
+    // Test on one pattern in isolation.
     if (lowerToOuterProduct) {
       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
       VectorTransformsOptions options{lowering};
@@ -71,9 +76,12 @@ struct TestVectorContractionConversion
       return;
     }
 
-    VectorContractLowering contractLowering = VectorContractLowering::FMA;
+    // Test on all contract lowering patterns.
+    VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)
       contractLowering = VectorContractLowering::Matmul;
+    else if (lowerToAXPY)
+      contractLowering = VectorContractLowering::AXPY;
     VectorTransposeLowering transposeLowering =
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)


        


More information about the Mlir-commits mailing list