[Mlir-commits] [mlir] 9578a54 - [mlir][Vector] Add vector contraction to outerproduct lowering

Nicolas Vasilache llvmlistbot at llvm.org
Tue May 26 06:33:09 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-26T09:31:26-04:00
New Revision: 9578a54f5007e8a02cef449dd151da27837b388e

URL: https://github.com/llvm/llvm-project/commit/9578a54f5007e8a02cef449dd151da27837b388e
DIFF: https://github.com/llvm/llvm-project/commit/9578a54f5007e8a02cef449dd151da27837b388e.diff

LOG: [mlir][Vector] Add vector contraction to outerproduct lowering

This revision adds the additional lowering and exposes the patterns at a finer granularity for better programmatic reuse. The unit test makes use of the finer grained pattern for simpler checks.

As the ContractionOpLowering is exposed programmatically, cleanup opportunities appear and static class methods are turned into free functions with static visibility.

Differential Revision: https://reviews.llvm.org/D80375

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 6394fae21375..423c72da6471 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -25,13 +25,6 @@ class MLIRContext;
 class OwningRewritePatternList;
 namespace vector {
 
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
-  /// Let vector.contract lower to vector.matrix_multiply and LLVM matrix
-  /// intrinsics.
-  bool lowerToLLVMMatrixIntrinsics = false;
-};
-
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
@@ -51,6 +44,20 @@ void populateVectorToVectorTransformationPatterns(
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// Enum to control the lowering of `vector.contract` operations.
+enum class VectorContractLowering {
+  /// Progressively lower to finer grained `vector.contract` and `vector.fma`.
+  FMA = 0,
+  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+  Matmul = 1,
+  /// Lower to `vector.outerproduct`.
+  OuterProduct = 2,
+};
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+  VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+};
+
 /// Collect a set of transformation patterns that are related to contracting
 /// or expanding vector operations:
 ///   ContractionOpLowering,

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 264c8ad034c8..1b978e44dd6a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -686,6 +686,11 @@ def Vector_OuterProductOp :
     return %3: vector<4x8xf32>
     ```
   }];
+  let builders = [
+    // Build an op without mask, use the type of `acc` as the return type.
+    OpBuilder<
+    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
+    "Value acc">];
   let extraClassDeclaration = [{
     VectorType getOperandVectorTypeLHS() {
       return lhs().getType().cast<VectorType>();

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 337ac75f7cbb..08aa579d651b 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -9,6 +9,7 @@
 #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_
 
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -22,13 +23,6 @@ void populateVectorToVectorConversionPatterns(
     ArrayRef<int64_t> coarseVectorShape = {},
     ArrayRef<int64_t> fineVectorShape = {});
 
-////////////////////////////////////////////////////////////////////////////////
-// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
-// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
-// it is the responsibility of the enclosing PatternRewriter to erase on
-// success.
-////////////////////////////////////////////////////////////////////////////////
-
 namespace vector {
 
 // Entry point for unrolling declarative pattern rewrites.
@@ -69,6 +63,115 @@ unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
                                  ArrayRef<int64_t> targetShape);
 
 } // namespace vector
+
+//===----------------------------------------------------------------------===//
+// Finer-grained patterns exposed for more control over individual lowerings.
+//===----------------------------------------------------------------------===//
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
+  LogicalResult match(vector::ContractionOp op) const override;
+  void rewrite(vector::ContractionOp op,
+               PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToOuterProductOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  ContractionOpToOuterProductOpLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
+  LogicalResult match(vector::ContractionOp op) const override;
+  void rewrite(vector::ContractionOp op,
+               PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+};
+
+/// Progressive lowering of ContractionOp.
+///
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a fma/reduction op.
+///
+/// This only kicks in when either VectorTransformsOptions is set to FMA or when
+/// other contraction patterns fail.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+                        MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+  // Lower one parallel dimension.
+  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                      int64_t rhsIndex, PatternRewriter &rewriter) const;
+  // Lower one reduction dimension.
+  Value lowerReduction(vector::ContractionOp op,
+                       PatternRewriter &rewriter) const;
+};
+
 } // namespace mlir
 
 #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5439233c96b1..1574edb34494 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -957,6 +957,13 @@ static LogicalResult verify(InsertStridedSliceOp op) {
 // OuterProductOp
 //===----------------------------------------------------------------------===//
 
+/// Build an op without mask, use the type of `acc` as the return type.
+void OuterProductOp::build(OpBuilder &builder, OperationState &result,
+                           Value lhs, Value rhs, Value acc) {
+  result.addOperands({lhs, rhs, acc});
+  result.addTypes(acc.getType());
+}
+
 static void print(OpAsmPrinter &p, OuterProductOp op) {
   p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
   if (!op.acc().empty())

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1c1de155d8b6..44ff03a04f22 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -39,6 +39,120 @@
 using namespace mlir;
 using llvm::dbgs;
 
+// Helper to find an index in an affine map.
+static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+    if (idx == index)
+      return i;
+  }
+  return None;
+}
+
+// Helper to construct iterator types with one index removed.
+static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
+                                            int64_t index) {
+  SmallVector<Attribute, 4> results;
+  for (auto it : llvm::enumerate(iteratorTypes)) {
+    int64_t idx = it.index();
+    if (idx == index)
+      continue;
+    results.push_back(it.value());
+  }
+  return results;
+}
+
+// Helper to construct an affine map with one index removed.
+static AffineMap adjustMap(AffineMap map, int64_t index,
+                           PatternRewriter &rewriter) {
+  auto *ctx = rewriter.getContext();
+  SmallVector<AffineExpr, 4> results;
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+    if (idx == index)
+      continue;
+    // Re-insert remaining indices, but renamed when occurring
+    // after the removed index.
+    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
+    results.push_back(targetExpr);
+  }
+  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
+}
+
+// Helper to drop dimension from vector type.
+static Type adjustType(VectorType tp, int64_t index) {
+  int64_t rank = tp.getRank();
+  Type eltType = tp.getElementType();
+  if (rank == 1) {
+    assert(index == 0 && "index for scalar result out of bounds");
+    return eltType;
+  }
+  SmallVector<int64_t, 4> adjustedShape;
+  for (int64_t i = 0; i < rank; ++i) {
+    // Omit dimension at the given index.
+    if (i == index)
+      continue;
+    // Otherwise, add dimension back.
+    adjustedShape.push_back(tp.getDimSize(i));
+  }
+  return VectorType::get(adjustedShape, eltType);
+}
+
+// Helper method to possibly drop a dimension in a load.
+// TODO(ajcbik): use a reshaping vector load (and share lowering code)
+static Value reshapeLoad(Location loc, Value val, VectorType type,
+                         int64_t index, int64_t pos,
+                         PatternRewriter &rewriter) {
+  if (index == -1)
+    return val;
+  Type lowType = adjustType(type, 0);
+  // At extraction dimension?
+  if (index == 0) {
+    auto posAttr = rewriter.getI64ArrayAttr(pos);
+    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
+  }
+  // Unroll leading dimensions.
+  VectorType vType = lowType.cast<VectorType>();
+  VectorType resType = adjustType(type, index).cast<VectorType>();
+  Value result =
+      rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
+  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
+    auto posAttr = rewriter.getI64ArrayAttr(d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
+    result =
+        rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
+  }
+  return result;
+}
+
+// Helper method to possibly drop a dimension in a store.
+// TODO(ajcbik): use a reshaping vector store (and share lowering code)
+static Value reshapeStore(Location loc, Value val, Value result,
+                          VectorType type, int64_t index, int64_t pos,
+                          PatternRewriter &rewriter) {
+  // Unmodified?
+  if (index == -1)
+    return val;
+  // At insertion dimension?
+  if (index == 0) {
+    auto posAttr = rewriter.getI64ArrayAttr(pos);
+    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
+  }
+  // Unroll leading dimensions.
+  Type lowType = adjustType(type, 0);
+  VectorType vType = lowType.cast<VectorType>();
+  Type insType = adjustType(vType, 0);
+  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
+    auto posAttr = rewriter.getI64ArrayAttr(d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
+    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+  }
+  return result;
+}
+
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -1252,343 +1366,6 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
   }
 };
 
-/// Progressive lowering of ContractionOp.
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a fma/reduction op.
-///
-/// TODO(ajcbik): break down into transpose/reshape/cast ops
-///               when they become available to avoid code dup
-/// TODO(ajcbik): investigate lowering order impact on performance
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
-                        MLIRContext *context)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO(ajcbik): implement masks
-    if (llvm::size(op.masks()) != 0)
-      return failure();
-
-    // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
-    // a new pattern.
-    if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
-        isRowMajorMatmul(op.indexing_maps())) {
-      VectorType lhsType = op.getLhsType();
-      VectorType rhsType = op.getRhsType();
-      unsigned lhsRows = op.getLhsType().getShape()[0];
-      unsigned lhsColumns = op.getLhsType().getShape()[1];
-      unsigned rhsColumns = op.getRhsType().getShape()[1];
-
-      Type flattenedLHSType =
-          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-      Type flattenedRHSType =
-          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-      auto lhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedLHSType, op.lhs());
-      auto rhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedRHSType, op.rhs());
-
-      Value mul = rewriter.create<vector::MatmulOp>(
-          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
-                                                 op.acc().getType(), mul);
-      Type elementType = op.getLhsType().getElementType();
-      assert(elementType.isIntOrFloat());
-      if (elementType.isa<IntegerType>())
-        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-      else
-        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-      return success();
-    }
-
-    // Find first batch dimension in LHS/RHS, and lower when found.
-    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
-    if (!batchDimMap.empty()) {
-      int64_t lhsIndex = batchDimMap[0].first;
-      int64_t rhsIndex = batchDimMap[0].second;
-      rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
-      return success();
-    }
-
-    // Collect contracting dimensions.
-    std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
-        op.getContractingDimMap();
-    DenseSet<int64_t> lhsContractingDimSet;
-    DenseSet<int64_t> rhsContractingDimSet;
-    for (auto &dimPair : contractingDimMap) {
-      lhsContractingDimSet.insert(dimPair.first);
-      rhsContractingDimSet.insert(dimPair.second);
-    }
-
-    // Find first free dimension in LHS, and lower when found.
-    VectorType lhsType = op.getLhsType();
-    for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
-         ++lhsIndex) {
-      if (lhsContractingDimSet.count(lhsIndex) == 0) {
-        rewriter.replaceOp(
-            op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
-        return success();
-      }
-    }
-
-    // Find first free dimension in RHS, and lower when found.
-    VectorType rhsType = op.getRhsType();
-    for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
-         ++rhsIndex) {
-      if (rhsContractingDimSet.count(rhsIndex) == 0) {
-        rewriter.replaceOp(
-            op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
-        return success();
-      }
-    }
-
-    // Lower the first remaining reduction dimension.
-    if (!contractingDimMap.empty()) {
-      rewriter.replaceOp(op, lowerReduction(op, rewriter));
-      return success();
-    }
-
-    return failure();
-  }
-
-private:
-  // Lower one parallel dimension.
-  // TODO(ajcbik): consider reusing existing contract unrolling
-  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                      int64_t rhsIndex, PatternRewriter &rewriter) const {
-    VectorType lhsType = op.getLhsType();
-    VectorType rhsType = op.getRhsType();
-    VectorType resType = op.getResultType().cast<VectorType>();
-    // Find the iterator type index and result index.
-    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
-    int64_t iterIndex = -1;
-    int64_t dimSize = -1;
-    if (lhsIndex >= 0) {
-      iterIndex =
-          iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
-      assert((rhsIndex < 0 || iterIndex == iMap[1]
-                                               .getResult(rhsIndex)
-                                               .cast<AffineDimExpr>()
-                                               .getPosition()) &&
-             "parallel index should be free in LHS or batch in LHS/RHS");
-      dimSize = lhsType.getDimSize(lhsIndex);
-    } else {
-      assert(rhsIndex >= 0 && "missing parallel index");
-      iterIndex =
-          iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
-      dimSize = rhsType.getDimSize(rhsIndex);
-    }
-    assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
-    Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
-    assert(lookup.hasValue() && "parallel index not listed in reduction");
-    int64_t resIndex = lookup.getValue();
-    // Construct new iterator types and affine map array attribute.
-    SmallVector<AffineMap, 4> lowIndexingMaps;
-    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
-    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
-    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
-    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-    auto lowIter =
-        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
-    // Unroll into a series of lower dimensional vector.contract ops.
-    Location loc = op.getLoc();
-    Value result = rewriter.create<ConstantOp>(loc, resType,
-                                               rewriter.getZeroAttr(resType));
-    for (int64_t d = 0; d < dimSize; ++d) {
-      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
-      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
-      auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
-      Value lowContract = rewriter.create<vector::ContractionOp>(
-          loc, lhs, rhs, acc, lowAffine, lowIter);
-      result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
-                            rewriter);
-    }
-    return result;
-  }
-
-  // Lower one reduction dimension.
-  Value lowerReduction(vector::ContractionOp op,
-                       PatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    VectorType lhsType = op.getLhsType();
-    VectorType rhsType = op.getRhsType();
-    Type resType = op.getResultType();
-    assert(!resType.isa<VectorType>());
-    // Use iterator index 0.
-    int64_t iterIndex = 0;
-    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
-    Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
-    Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
-    assert(lookupLhs.hasValue() && "missing LHS parallel index");
-    assert(lookupRhs.hasValue() && "missing RHS parallel index");
-    int64_t lhsIndex = lookupLhs.getValue();
-    int64_t rhsIndex = lookupRhs.getValue();
-    int64_t dimSize = lhsType.getDimSize(lhsIndex);
-    assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
-    // Base case.
-    if (lhsType.getRank() == 1) {
-      assert(rhsType.getRank() == 1 && "corrupt contraction");
-      Value zero = rewriter.create<ConstantOp>(loc, lhsType,
-                                               rewriter.getZeroAttr(lhsType));
-      Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
-      StringAttr kind = rewriter.getStringAttr("add");
-      return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
-                                                  op.acc());
-    }
-    // Construct new iterator types and affine map array attribute.
-    SmallVector<AffineMap, 4> lowIndexingMaps;
-    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
-    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
-    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
-    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-    auto lowIter =
-        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
-    // Unroll into a series of lower dimensional vector.contract ops.
-    // By feeding the initial accumulator into the first contraction,
-    // and the result of each contraction into the next, eventually
-    // the sum of all reductions is computed.
-    Value result = op.acc();
-    for (int64_t d = 0; d < dimSize; ++d) {
-      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
-      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
-      result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
-                                                      lowAffine, lowIter);
-    }
-    return result;
-  }
-
-  // Helper to find an index in an affine map.
-  static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
-    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
-      if (idx == index)
-        return i;
-    }
-    return None;
-  }
-
-  // Helper to construct iterator types with one index removed.
-  static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
-                                              int64_t index) {
-    SmallVector<Attribute, 4> results;
-    for (auto it : llvm::enumerate(iteratorTypes)) {
-      int64_t idx = it.index();
-      if (idx == index)
-        continue;
-      results.push_back(it.value());
-    }
-    return results;
-  }
-
-  // Helper to construct an affine map with one index removed.
-  static AffineMap adjustMap(AffineMap map, int64_t index,
-                             PatternRewriter &rewriter) {
-    auto *ctx = rewriter.getContext();
-    SmallVector<AffineExpr, 4> results;
-    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
-      if (idx == index)
-        continue;
-      // Re-insert remaining indices, but renamed when occurring
-      // after the removed index.
-      auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
-      results.push_back(targetExpr);
-    }
-    return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
-  }
-
-  // Helper to drop dimension from vector type.
-  static Type adjustType(VectorType tp, int64_t index) {
-    int64_t rank = tp.getRank();
-    Type eltType = tp.getElementType();
-    if (rank == 1) {
-      assert(index == 0 && "index for scalar result out of bounds");
-      return eltType;
-    }
-    SmallVector<int64_t, 4> adjustedShape;
-    for (int64_t i = 0; i < rank; ++i) {
-      // Omit dimension at the given index.
-      if (i == index)
-        continue;
-      // Otherwise, add dimension back.
-      adjustedShape.push_back(tp.getDimSize(i));
-    }
-    return VectorType::get(adjustedShape, eltType);
-  }
-
-  // Helper method to possibly drop a dimension in a load.
-  // TODO(ajcbik): use a reshaping vector load (and share lowering code)
-  static Value reshapeLoad(Location loc, Value val, VectorType type,
-                           int64_t index, int64_t pos,
-                           PatternRewriter &rewriter) {
-    if (index == -1)
-      return val;
-    Type lowType = adjustType(type, 0);
-    // At extraction dimension?
-    if (index == 0) {
-      auto posAttr = rewriter.getI64ArrayAttr(pos);
-      return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
-    }
-    // Unroll leading dimensions.
-    VectorType vType = lowType.cast<VectorType>();
-    VectorType resType = adjustType(type, index).cast<VectorType>();
-    Value result = rewriter.create<ConstantOp>(loc, resType,
-                                               rewriter.getZeroAttr(resType));
-    for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
-      auto posAttr = rewriter.getI64ArrayAttr(d);
-      Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
-      Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-      result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
-                                                 posAttr);
-    }
-    return result;
-  }
-
-  // Helper method to possibly drop a dimension in a store.
-  // TODO(ajcbik): use a reshaping vector store (and share lowering code)
-  static Value reshapeStore(Location loc, Value val, Value result,
-                            VectorType type, int64_t index, int64_t pos,
-                            PatternRewriter &rewriter) {
-    // Unmodified?
-    if (index == -1)
-      return val;
-    // At insertion dimension?
-    if (index == 0) {
-      auto posAttr = rewriter.getI64ArrayAttr(pos);
-      return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
-    }
-    // Unroll leading dimensions.
-    Type lowType = adjustType(type, 0);
-    VectorType vType = lowType.cast<VectorType>();
-    Type insType = adjustType(vType, 0);
-    for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
-      auto posAttr = rewriter.getI64ArrayAttr(d);
-      Value ext =
-          rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
-      Value ins =
-          rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
-      Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
-      result =
-          rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
-    }
-    return result;
-  }
-
-  vector::VectorTransformsOptions vectorTransformsOptions;
-};
-
 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
 /// vectors progressively on the way to target llvm.matrix intrinsics.
 /// This iterates over the most major dimension of the 2-D vector and performs
@@ -1656,6 +1433,302 @@ class ShapeCastOp2DUpCastRewritePattern
 
 } // namespace
 
+namespace mlir {
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+LogicalResult
+ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
+  // TODO(ajcbik): implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+          vector::VectorContractLowering::Matmul ||
+      !isRowMajorMatmul(op.indexing_maps()))
+    return failure();
+  return success();
+}
+
+void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
+                                              PatternRewriter &rewriter) const {
+  VectorType lhsType = op.getLhsType();
+  VectorType rhsType = op.getRhsType();
+  unsigned lhsRows = op.getLhsType().getShape()[0];
+  unsigned lhsColumns = op.getLhsType().getShape()[1];
+  unsigned rhsColumns = op.getRhsType().getShape()[1];
+
+  Type flattenedLHSType =
+      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  Type flattenedRHSType =
+      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+  auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
+                                                  op.lhs());
+  auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
+                                                  op.rhs());
+
+  Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
+                                                lhsColumns, rhsColumns);
+  mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
+                                             mul);
+  Type elementType = op.getLhsType().getElementType();
+  assert(elementType.isIntOrFloat());
+  if (elementType.isa<IntegerType>())
+    rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+  else
+    rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+}
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// the vector.contract op is a row-major matrix multiply.
+void ContractionOpToOuterProductOpLowering::rewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  VectorType lhsType = op.getLhsType();
+  // TODO(ntv) other modes.
+  // We know we are in row-major.
+  bool transposeLhs = false;
+  unsigned reductionSize =
+      transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
+
+  // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
+  // transpose it to extract the proper vector<m x f32>. Otherwise, just take
+  // the lhs.
+  Value lhs = transposeLhs
+                  ? op.lhs()
+                  : rewriter.create<vector::TransposeOp>(
+                        op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
+  Value res = op.acc();
+  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+  for (unsigned k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
+    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
+    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
+  }
+  rewriter.replaceOp(op, res);
+}
+
+LogicalResult
+ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
+  // TODO(ajcbik): implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+          vector::VectorContractLowering::OuterProduct ||
+      !isRowMajorMatmul(op.indexing_maps()))
+    return failure();
+  return success();
+}
+
+/// Progressive lowering of ContractionOp.
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a fma/reduction op.
+///
+/// TODO(ajcbik): break down into transpose/reshape/cast ops
+///               when they become available to avoid code dup
+/// TODO(ajcbik): investigate lowering order impact on performance
+LogicalResult
+ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
+                                       PatternRewriter &rewriter) const {
+
+  // TODO(ajcbik): implement masks.
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  // TODO(ntv, ajcbik): implement benefits, cost models.
+  MLIRContext *ctx = op.getContext();
+  ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
+  if (succeeded(pat1.match(op)))
+    return failure();
+  ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
+  if (succeeded(pat2.match(op)))
+    return failure();
+
+  // Find first batch dimension in LHS/RHS, and lower when found.
+  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+  if (!batchDimMap.empty()) {
+    int64_t lhsIndex = batchDimMap[0].first;
+    int64_t rhsIndex = batchDimMap[0].second;
+    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+    return success();
+  }
+
+  // Collect contracting dimensions.
+  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+      op.getContractingDimMap();
+  DenseSet<int64_t> lhsContractingDimSet;
+  DenseSet<int64_t> rhsContractingDimSet;
+  for (auto &dimPair : contractingDimMap) {
+    lhsContractingDimSet.insert(dimPair.first);
+    rhsContractingDimSet.insert(dimPair.second);
+  }
+
+  // Find first free dimension in LHS, and lower when found.
+  VectorType lhsType = op.getLhsType();
+  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
+    if (lhsContractingDimSet.count(lhsIndex) == 0) {
+      rewriter.replaceOp(
+          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+      return success();
+    }
+  }
+
+  // Find first free dimension in RHS, and lower when found.
+  VectorType rhsType = op.getRhsType();
+  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
+    if (rhsContractingDimSet.count(rhsIndex) == 0) {
+      rewriter.replaceOp(
+          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+      return success();
+    }
+  }
+
+  // Lower the first remaining reduction dimension.
+  if (!contractingDimMap.empty()) {
+    rewriter.replaceOp(op, lowerReduction(op, rewriter));
+    return success();
+  }
+
+  return failure();
+}
+
+// Lower one parallel dimension.
+// TODO(ajcbik): consider reusing existing contract unrolling
+Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
+                                           int64_t lhsIndex, int64_t rhsIndex,
+                                           PatternRewriter &rewriter) const {
+  VectorType lhsType = op.getLhsType();
+  VectorType rhsType = op.getRhsType();
+  VectorType resType = op.getResultType().cast<VectorType>();
+  // Find the iterator type index and result index.
+  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+  int64_t iterIndex = -1;
+  int64_t dimSize = -1;
+  if (lhsIndex >= 0) {
+    iterIndex = iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
+    assert(
+        (rhsIndex < 0 ||
+         iterIndex ==
+             iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition()) &&
+        "parallel index should be free in LHS or batch in LHS/RHS");
+    dimSize = lhsType.getDimSize(lhsIndex);
+  } else {
+    assert(rhsIndex >= 0 && "missing parallel index");
+    iterIndex = iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
+    dimSize = rhsType.getDimSize(rhsIndex);
+  }
+  assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
+  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
+  assert(lookup.hasValue() && "parallel index not listed in reduction");
+  int64_t resIndex = lookup.getValue();
+  // Construct new iterator types and affine map array attribute.
+  SmallVector<AffineMap, 4> lowIndexingMaps;
+  lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+  lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+  lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+  auto lowIter =
+      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+  // Unroll into a series of lower dimensional vector.contract ops.
+  Location loc = op.getLoc();
+  Value result =
+      rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
+  for (int64_t d = 0; d < dimSize; ++d) {
+    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+    auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
+    Value lowContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, acc, lowAffine, lowIter);
+    result =
+        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
+  }
+  return result;
+}
+
+// Lower one reduction dimension.
+Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  VectorType lhsType = op.getLhsType();
+  VectorType rhsType = op.getRhsType();
+  Type resType = op.getResultType();
+  assert(!resType.isa<VectorType>());
+  // Use iterator index 0.
+  int64_t iterIndex = 0;
+  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+  assert(lookupLhs.hasValue() && "missing LHS parallel index");
+  assert(lookupRhs.hasValue() && "missing RHS parallel index");
+  int64_t lhsIndex = lookupLhs.getValue();
+  int64_t rhsIndex = lookupRhs.getValue();
+  int64_t dimSize = lhsType.getDimSize(lhsIndex);
+  assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+  // Base case.
+  if (lhsType.getRank() == 1) {
+    assert(rhsType.getRank() == 1 && "corrupt contraction");
+    Value zero = rewriter.create<ConstantOp>(loc, lhsType,
+                                             rewriter.getZeroAttr(lhsType));
+    Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+    StringAttr kind = rewriter.getStringAttr("add");
+    return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
+                                                op.acc());
+  }
+  // Construct new iterator types and affine map array attribute.
+  SmallVector<AffineMap, 4> lowIndexingMaps;
+  lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+  lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+  lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+  auto lowIter =
+      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+  // Unroll into a series of lower dimensional vector.contract ops.
+  // By feeding the initial accumulator into the first contraction,
+  // and the result of each contraction into the next, eventually
+  // the sum of all reductions is computed.
+  Value result = op.acc();
+  for (int64_t d = 0; d < dimSize; ++d) {
+    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
+                                                    lowAffine, lowIter);
+  }
+  return result;
+}
+
+} // namespace mlir
+
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO(andydavis) Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -1685,6 +1758,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern,
                   TransposeOpLowering>(context);
+  patterns.insert<ContractionOpLowering,
+                  ContractionOpToMatmulOpLowering,
+                  ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
-  patterns.insert<ContractionOpLowering>(parameters, context);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 72270dab1153..7eea3baa8d87 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
+// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s --dump-input-on-failure
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX --dump-input-on-failure
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT --dump-input-on-failure
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -382,6 +383,35 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
 //      MATRIX:  %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 //      MATRIX:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
 //      MATRIX:  %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32>
+
+// OUTERPRODUCT-LABEL: func @matmul
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT-SAME:  : vector<2x4xf32> to vector<4x2xf32>
+//
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32>
+//      OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
+//      OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32>
+//      OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
+//      OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32>
+//      OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
+//      OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>
 func @matmul(%arg0: vector<2x4xf32>,
                           %arg1: vector<4x3xf32>,
                           %arg2: vector<2x3xf32>) -> vector<2x3xf32> {

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index c57540bc2ef7..65024dbe3acd 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -51,11 +51,26 @@ struct TestVectorContractionConversion
       *this, "vector-lower-matrix-intrinsics",
       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
       llvm::cl::init(false)};
+  Option<bool> lowerToOuterProduct{
+      *this, "vector-outerproduct",
+      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
+      llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
-    VectorTransformsOptions options{
-        /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics};
+    if (lowerToOuterProduct) {
+      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
+      VectorTransformsOptions options{lowering};
+      patterns.insert<ContractionOpToOuterProductOpLowering>(options,
+                                                             &getContext());
+      applyPatternsAndFoldGreedily(getFunction(), patterns);
+      return;
+    }
+
+    VectorContractLowering lowering = VectorContractLowering::FMA;
+    if (lowerToLLVMMatrixIntrinsics)
+      lowering = VectorContractLowering::Matmul;
+    VectorTransformsOptions options{lowering};
     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }


        


More information about the Mlir-commits mailing list