[Mlir-commits] [mlir] f245b7a - [mlir][Linalg] Generalize the definition of a Linalg contraction.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 3 23:55:09 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-04T07:50:44Z
New Revision: f245b7ad36ff8bd85cddbe9784f7efe6dee577c0

URL: https://github.com/llvm/llvm-project/commit/f245b7ad36ff8bd85cddbe9784f7efe6dee577c0
DIFF: https://github.com/llvm/llvm-project/commit/f245b7ad36ff8bd85cddbe9784f7efe6dee577c0.diff

LOG: [mlir][Linalg] Generalize the definition of a Linalg contraction.

This revision defines a Linalg contraction in general terms:

  1. Has 2 input and 1 output shapes.
  2. Has at least one reduction dimension.
  3. Has only projected permutation indexing maps.
  4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
     (AddOpType, MulOpType), where a and b are the two input elements, c is the
     output element, and u1, u2, u3, u4 and u5 represent scalar unary
     operations that may change the type (e.g. for mixed-precision).

As a consequence, when vectorization of such an op occurs, the only special
behavior is that the (unique) MulOpType is vectorized into a
`vector.contract`. All other ops are handled in a generic fashion.

In the future, we may wish to allow more input arguments and elementwise and
constant operations that do not involve the reduction dimension(s).

A test is added to demonstrate the proper vectorization of matmul_i8_i8_i32.

Differential revision: https://reviews.llvm.org/D95939

Added: 
    

Modified: 
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 7791ed0d5eee..5e577d778210 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -77,7 +77,7 @@ func @main() {
   scf.for %arg0 = %c0 to %iters step %c1 {
     // linalg.matmul writes %C in place, need to reset it to zero every time.
     // This is accounts for about 10-15% perf hit on small sizes.
-    // Once linalg on tensors is ready, fusing fill at teh register level will
+    // Once linalg on tensors is ready, fusing fill at the register level will
     // be easy.
     %z = constant 0.0 : !elem_type_c
     linalg.fill(%C, %z) : !row_major_C, !elem_type_c

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
index e454c7cb8160..de4e51bd8c0e 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
@@ -75,7 +75,7 @@ func @main() {
   scf.for %arg0 = %c0 to %iters step %c1 {
     // linalg.matmul writes %C in place, need to reset it to zero every time.
     // This is accounts for about 10-15% perf hit on small sizes.
-    // Once linalg on tensors is ready, fusing fill at teh register level will
+    // Once linalg on tensors is ready, fusing fill at the register level will
     // be easy.
     linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
     call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
index 287cb1c24059..95fc57506c43 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
@@ -84,7 +84,7 @@ func @main() {
   scf.for %arg0 = %c0 to %iters step %c1 {
     // linalg.matmul writes %C in place, need to reset it to zero every time.
     // This is accounts for about 10-15% perf hit on small sizes.
-    // Once linalg on tensors is ready, fusing fill at teh register level will
+    // Once linalg on tensors is ready, fusing fill at the register level will
     // be easy.
     linalg.fill(%C, %f0) : !row_major_C, !elem_type_c
     call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
index 961a83fd3f57..abfb14739e25 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
@@ -1,12 +1,11 @@
 // RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
 // RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
-// TODO: extend vectorization with interfaces so that it works with sexti
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16 vectorize" | \
 // RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
 // RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
 // RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
 
-// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -mlir-disable-threading | \
 // RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
 // Activate to dump assembly
 // R_UN:   -dump-object-file -object-filename=/tmp/a.o \
@@ -18,9 +17,9 @@
 !elem_type_a = type i8
 !elem_type_b = type i8
 !elem_type_c = type i32
-!row_major_A = type memref<${M}x${K}x!elem_type_a>
-!row_major_B = type memref<${K}x${N}x!elem_type_b>
-!row_major_C = type memref<${M}x${N}x!elem_type_c>
+// NOTE(review): these shapes are hard-coded to M=24, K=64, N=192, so the RUN
+// line's sed substitution of ${M}/${K}/${N} no longer affects them; keep them
+// in sync with the exported values or restore the placeholders. (Likewise,
+// %iters in @main is hard-coded to 100 while the RUN line exports ITERS=10.)
+!row_major_A = type memref<24x64x!elem_type_a>
+!row_major_B = type memref<64x192x!elem_type_b>
+!row_major_C = type memref<24x192x!elem_type_c>
 
 func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
 // TODO: activate manually for now.
@@ -33,9 +32,9 @@ func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
 
 func @print_perf(%iters: index, %total_time: f64) {
   %c2 = constant 2 : index
-  %cM = constant ${M} : index
-  %cN = constant ${N} : index
-  %cK = constant ${K} : index
+  %cM = constant 24 : index
+  %cN = constant 192 : index
+  %cK = constant 64 : index
 
   %mn = muli %cM, %cN : index
   %mnk = muli %mn, %cK : index
@@ -65,7 +64,7 @@ func @main() {
 
   %c0 = constant 0: index
   %c1 = constant 1: index
-  %iters = constant ${ITERS}: index
+  %iters = constant 100: index
 
   /// Run and dump performance for matmul.
   /// Preheating run:
@@ -77,7 +76,7 @@ func @main() {
   scf.for %arg0 = %c0 to %iters step %c1 {
     // linalg.matmul writes %C in place, need to reset it to zero every time.
     // This is accounts for about 10-15% perf hit on small sizes.
-    // Once linalg on tensors is ready, fusing fill at teh register level will
+    // Once linalg on tensors is ready, fusing fill at the register level will
     // be easy.
     linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
     call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index fc1266e3608a..fb9d452cfcef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -38,6 +38,70 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "linalg-vectorization"
 
+/// Return true if the use-def chain from `v` to `from` consists of 0 or more
+/// unary single-operand operations.
+/// The chain is followed upwards through each defining op's sole operand. It
+/// stops with `false` as soon as a value has no defining op (e.g. a block
+/// argument other than `from`) or its defining op does not have exactly one
+/// operand, and with `true` once `from` is reached (immediately if v == from).
+// TODO: relax to multi-operands with constants, which are technically unary ops
+// as needed (e.g. add5).
+static bool isChainOfUnaryOpsFrom(Value v, Value from) {
+  while (v != from) {
+    Operation *op = v.getDefiningOp();
+    if (!op || op->getNumOperands() != 1)
+      return false;
+    v = op->getOperand(0);
+  }
+  return true;
+}
+
+/// Search `block` (via Block::walk) for ops of kind `OpType`. Returns the op
+/// when exactly one is found; returns a null `OpType` when none is found or as
+/// soon as a second one shows up.
+template <typename OpType>
+static OpType getSingleOpOfType(Block &block) {
+  OpType unique;
+  block.walk([&](OpType candidate) {
+    if (!unique) {
+      // First match: remember it and keep looking for a duplicate.
+      unique = candidate;
+      return WalkResult::advance();
+    }
+    // Second match: not unique, so reset the result and stop the walk early.
+    unique = nullptr;
+    return WalkResult::interrupt();
+  });
+  return unique;
+}
+
+/// Detect whether the single value yielded by `block` is any permutation of
+/// `u5(u1(c) + u2(u3(a) * u4(b)))` on the field (AddOpType, MulOpType), where
+/// a, b and c are the block's three arguments (in that order) and u1, u2, u3,
+/// u4 and u5 represent chains of zero or more single-operand operations that
+/// may change the type.
+/// Requires exactly one AddOpType and exactly one MulOpType in the block, and
+/// ALL of the following to hold:
+///   - the yielded value traces back to the add;
+///   - one add operand traces back to `c`, the other to the mul;
+///   - one mul operand traces back to `a`, the other to `b`.
+template <typename AddOpType, typename MulOpType>
+static bool isAddMul(Block &block) {
+  if (block.getNumArguments() != 3)
+    return false;
+  Operation *yieldOp = block.getTerminator();
+  if (yieldOp->getNumOperands() != 1)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isAddMul: "; block.dump());
+  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
+  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
+  if (!addOp || !mulOp)
+    return false;
+
+  Value argA = block.getArgument(0), argB = block.getArgument(1);
+  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+  Value mul = mulOp->getResult(0);
+  Value argC = block.getArgument(2);
+  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
+  Value add = addOp->getResult(0);
+  Value res = yieldOp->getOperand(0);
+  auto un = isChainOfUnaryOpsFrom;
+  // Result traces back to add.
+  if (!un(res, add))
+    return false;
+  // One of the operands of add traces back to argC, the other to the mul.
+  // (These conditions must all hold; combining them with `|=` would accept a
+  // body as soon as any single one of them matched.)
+  if (!((un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC))))
+    return false;
+  // One of the operands of mul traces back to argA, the other to argB.
+  return (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
+}
+
 /// Helper data structure to represent the result of vectorization.
 /// In certain specific cases, like terminators, we do not want to propagate/
 enum VectorizationStatus {
@@ -146,7 +210,7 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
       results.push_back(result);
   }
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
-};
+}
 
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
@@ -305,55 +369,34 @@ static LogicalResult vectorizeAsLinalgGeneric(
   return success();
 }
 
-/// Detect whether `r` exactly computes a floating-point or integer
-/// multiply-accumulate.
-static bool hasMultiplyAddBody(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
-    return false;
-
-  using mlir::matchers::m_Val;
-  auto a = m_Val(r.getArgument(0));
-  auto b = m_Val(r.getArgument(1));
-  auto c = m_Val(r.getArgument(2));
-  // TODO: Update this detection once we have  matcher support for specifying
-  // that any permutation of operands matches.
-  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
-  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
-  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
-  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
-  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
-  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
-  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
-  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
-  return pattern1.match(&r.front().back()) ||
-         pattern2.match(&r.front().back()) ||
-         pattern3.match(&r.front().back()) ||
-         pattern4.match(&r.front().back()) ||
-         pattern5.match(&r.front().back()) ||
-         pattern6.match(&r.front().back()) ||
-         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
-}
-
 /// Detect whether the LinalgOp `op` is a contraction.
-// TODO: Should be Tablegen'd from a single source that generates the op itself.
+/// A Linalg contraction is defined in general terms:
+///   1. Has 2 input and 1 output shapes.
+///   2. Has at least one reduction dimension.
+///   3. Has only projected permutation indexing maps.
+///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
+///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
+///   operations that may change the type (e.g. for mixed-precision).
+/// As a consequence, when vectorization of such an op occurs, the only special
+/// behavior is that the (unique) MulOpType is vectorized into a
+/// `vector.contract`. All other ops are handled in a generic fashion.
+/// In the future, we may wish to allow more input arguments and elementwise and
+/// constant operations that do not involve the reduction dimension(s).
+/// Returns success() when `op` is a LinalgOp with a non-empty body region
+/// satisfying all of the above, and failure() otherwise.
 static LogicalResult isContraction(Operation *op) {
-  // TODO: interface for named ops.
-  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatmulColumnMajorOp,
-          linalg::MatvecOp, linalg::VecmatOp, linalg::DotOp>(op))
-    return success();
-
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp)
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isContraction: "; op->dump());
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
     return failure();
 
-  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
+  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
   return success(
-      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      linalgOp.getNumInputs() == 2 && linalgOp.getNumOutputs() == 1 &&
+      linalgOp.getNumReductionLoops() > 0 &&
       llvm::all_of(mapRange,
                    [](AffineMap m) { return m.isProjectedPermutation(); }) &&
-      hasMultiplyAddBody(genericOp.region()));
+      // Guard the body access below: getRegion(0).front() is only valid when
+      // the op actually carries a non-empty region.
+      linalgOp->getNumRegions() > 0 && !linalgOp->getRegion(0).empty() &&
+      // TODO: more fields than add/mul.
+      (isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) ||
+       isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())));
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -382,7 +425,7 @@ static bool isElementwise(Operation *op) {
     if (!genericOp.getOutputIndexingMap(i).isIdentity())
       return false;
   }
-  // Currently limit the input indexing map to minor identity as other
+  // Currently bound the input indexing map to minor identity as other
   // permutations might require adding transpose ops to convert the vector read
   // to the right shape.
   for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
@@ -479,6 +522,150 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
          "Unexpected vectorization failed despite preconditions");
 }
 
+//----------------------------------------------------------------------------//
+// Misc. conv vectorization patterns.
+//----------------------------------------------------------------------------//
+// TODO: cleanup all this.
+/// Rewrite a statically shaped ConvOp into a single `vector.contract`.
+/// `mask` (pattern state; presumably the constructor argument passed by
+/// populateVectorizationPatterns) selects which of the first N dimensions are
+/// vectorized. Every non-vectorized dimension must have extent 1 in both
+/// input 0 (`input`) and input 1 (`kernel`). Every vectorized dimension must
+/// have equal extents in both. Input and kernel are read as vectors starting at
+/// index 0 and contracted over all vectorized dimensions (all "reduction")
+/// into a scalar. That scalar is stored at index 0 of output buffer 0 and the
+/// original op is erased.
+/// All `failure()` returns happen before any IR is created.
+template <class ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+    ConvOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = op.getContext();
+  edsc::ScopedContext scope(rewriter, loc);
+
+  ShapedType inShapeType = op.getInputShapedType(0);
+  ShapedType kShapeType = op.getInputShapedType(1);
+
+  ArrayRef<int64_t> inShape = inShapeType.getShape();
+  ArrayRef<int64_t> kShape = kShapeType.getShape();
+
+  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+    return failure();
+
+  SmallVector<AffineExpr, 4> mapping;
+  SmallVector<int64_t, 4> vectorDims;
+  // Fail to apply when the size of not vectorized dimension is not 1.
+  // NOTE(review): `i` is unsigned while N is an `int` template parameter
+  // (sign-compare warning); this also assumes N does not exceed the rank of
+  // either shape — confirm at the instantiation sites.
+  for (unsigned i = 0; i < N; i++) {
+    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
+      return failure();
+
+    if (mask[i] && inShape[i] != kShape[i])
+      return failure();
+
+    if (mask[i]) {
+      mapping.push_back(getAffineDimExpr(i, context));
+      vectorDims.push_back(inShape[i]);
+    }
+  }
+
+  Value input = op.getInput(0);
+  Value kernel = op.getInput(1);
+  Value output = op.getOutputBuffer(0);
+
+  unsigned rank = inShapeType.getRank();
+  unsigned numDims = mapping.size();
+  Type elemType = inShapeType.getElementType();
+
+  // Permutation map from the `rank` memref indices onto the vectorized dims.
+  auto map = AffineMap::get(rank, 0, mapping, context);
+  // One index-0 constant is created and the same Value reused `rank` times.
+  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
+  auto vecType = VectorType::get(vectorDims, elemType);
+
+  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
+  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
+
+  // Zero accumulator of the input element type.
+  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
+
+  // Both operands use identity maps over the vectorized dims; the result map
+  // has no results, i.e. the contraction produces a scalar.
+  std::array<AffineMap, 3> indexingMaps{
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::get(numDims, 0, {}, context)};
+
+  std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+  auto result = rewriter.create<vector::ContractionOp>(
+      loc, inputVec, kernelVec, acc,
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      rewriter.getStrArrayAttr(iteratorTypes));
+
+  // NOTE(review): indexes the output with `rank` zeros (the *input* rank);
+  // assumes the output has the same rank as the input — confirm.
+  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
+  rewriter.eraseOp(op);
+  return success();
+}
+
+// NOTE(review): this alias is not referenced anywhere in the visible code;
+// consider removing it if nothing else in the file uses it.
+using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
+
+/// Register, for one ConvOp kind of rank N, a tiling, a promotion and a
+/// vectorization pattern into the three given lists. Tiling results are tagged
+/// "TILED"; promotion matches "TILED" and tags its results "PROMOTED". The
+/// vectorization mask enables each of the trailing N tile sizes that is greater
+/// than 1. Nothing is registered when fewer than N tile sizes are given.
+template <typename ConvOp, unsigned N>
+static void
+populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
+                              OwningRewritePatternList &promotionPatterns,
+                              OwningRewritePatternList &vectorizationPatterns,
+                              ArrayRef<int64_t> tileSizes,
+                              MLIRContext *context) {
+  // Too few tile sizes to describe every dimension of this conv kind.
+  if (tileSizes.size() < N)
+    return;
+
+  constexpr static StringRef kTiledMarker = "TILED";
+  constexpr static StringRef kPromotedMarker = "PROMOTED";
+  Identifier tiledId = Identifier::get(kTiledMarker, context);
+  Identifier promotedId = Identifier::get(kPromotedMarker, context);
+
+  // Stage 1: tile with an empty match list, tagging the result "TILED".
+  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
+      context, LinalgTilingOptions().setTileSizes(tileSizes),
+      LinalgTransformationFilter(ArrayRef<Identifier>{}, tiledId));
+
+  // Stage 2: promote "TILED" ops to full-tile buffers, tagging "PROMOTED".
+  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
+      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+      LinalgTransformationFilter(tiledId, promotedId));
+
+  // Stage 3: vectorize along each trailing dimension tiled by more than 1.
+  SmallVector<bool, 4> mask;
+  mask.reserve(N);
+  for (int64_t size : tileSizes.take_back(N))
+    mask.push_back(size > 1);
+
+  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
+}
+
+/// Build the tiling, promotion and vectorization pattern lists for every conv
+/// kind registered below and append them to `patterns` in that order (tiling,
+/// promotion, vectorization). Conv kinds whose rank template argument exceeds
+/// `tileSizes.size()` contribute no patterns.
+void mlir::linalg::populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+    ArrayRef<int64_t> tileSizes) {
+  OwningRewritePatternList tilingList;
+  OwningRewritePatternList promotionList;
+  OwningRewritePatternList vectorizationList;
+
+  // Registration order is intentionally fixed:
+  // W, NWC, NCW, HW, NHWC, NCHW, DHW, NDHWC, NCDHW.
+  populateVectorizationPatterns<ConvWOp, 1>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNWCOp, 3>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNCWOp, 3>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvHWOp, 2>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNHWCOp, 4>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNCHWOp, 4>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvDHWOp, 3>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNDHWCOp, 5>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+  populateVectorizationPatterns<ConvNCDHWOp, 5>(
+      tilingList, promotionList, vectorizationList, tileSizes, context);
+
+  // Append the three stage lists: tiling, then promotion, then vectorization.
+  patterns.push_back(std::move(tilingList));
+  patterns.push_back(std::move(promotionList));
+  patterns.push_back(std::move(vectorizationList));
+}
+
+//----------------------------------------------------------------------------//
+// Forwarding patterns
+//----------------------------------------------------------------------------//
+
 /// Check whether there is any interleaved use of any `values` between `firstOp`
 /// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
@@ -649,139 +836,3 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 
   return success();
 }
-
-template <class ConvOp, int N>
-LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
-    ConvOp op, PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  MLIRContext *context = op.getContext();
-  edsc::ScopedContext scope(rewriter, loc);
-
-  ShapedType inShapeType = op.getInputShapedType(0);
-  ShapedType kShapeType = op.getInputShapedType(1);
-
-  ArrayRef<int64_t> inShape = inShapeType.getShape();
-  ArrayRef<int64_t> kShape = kShapeType.getShape();
-
-  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
-    return failure();
-
-  SmallVector<AffineExpr, 4> mapping;
-  SmallVector<int64_t, 4> vectorDims;
-  // Fail to apply when the size of not vectorized dimension is not 1.
-  for (unsigned i = 0; i < N; i++) {
-    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
-      return failure();
-
-    if (mask[i] && inShape[i] != kShape[i])
-      return failure();
-
-    if (mask[i]) {
-      mapping.push_back(getAffineDimExpr(i, context));
-      vectorDims.push_back(inShape[i]);
-    }
-  }
-
-  Value input = op.getInput(0);
-  Value kernel = op.getInput(1);
-  Value output = op.getOutputBuffer(0);
-
-  unsigned rank = inShapeType.getRank();
-  unsigned numDims = mapping.size();
-  Type elemType = inShapeType.getElementType();
-
-  auto map = AffineMap::get(rank, 0, mapping, context);
-  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
-  auto vecType = VectorType::get(vectorDims, elemType);
-
-  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
-  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
-
-  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
-
-  std::array<AffineMap, 3> indexingMaps{
-      AffineMap::getMultiDimIdentityMap(numDims, context),
-      AffineMap::getMultiDimIdentityMap(numDims, context),
-      AffineMap::get(numDims, 0, {}, context)};
-
-  std::vector<StringRef> iteratorTypes(numDims, "reduction");
-
-  auto result = rewriter.create<vector::ContractionOp>(
-      loc, inputVec, kernelVec, acc,
-      rewriter.getAffineMapArrayAttr(indexingMaps),
-      rewriter.getStrArrayAttr(iteratorTypes));
-
-  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
-  rewriter.eraseOp(op);
-  return success();
-}
-
-using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
-
-/// Inserts tiling, promotion and vectorization pattern for ConvOp
-/// conversion into corresponding pattern lists.
-template <typename ConvOp, unsigned N>
-static void
-populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
-                              OwningRewritePatternList &promotionPatterns,
-                              OwningRewritePatternList &vectorizationPatterns,
-                              ArrayRef<int64_t> tileSizes,
-                              MLIRContext *context) {
-  if (tileSizes.size() < N)
-    return;
-
-  constexpr static StringRef kTiledMarker = "TILED";
-  constexpr static StringRef kPromotedMarker = "PROMOTED";
-  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
-      context, LinalgTilingOptions().setTileSizes(tileSizes),
-      LinalgTransformationFilter(ArrayRef<Identifier>{},
-                                 Identifier::get(kTiledMarker, context)));
-
-  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
-      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
-      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
-                                 Identifier::get(kPromotedMarker, context)));
-
-  SmallVector<bool, 4> mask(N);
-  int offset = tileSizes.size() - N;
-  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
-                 [](int64_t i) -> bool { return i > 1; });
-
-  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
-}
-
-void mlir::linalg::populateConvVectorizationPatterns(
-    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
-    ArrayRef<int64_t> tileSizes) {
-  OwningRewritePatternList tiling, promotion, vectorization;
-  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
-                                            tileSizes, context);
-
-  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
-
-  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
-
-  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
-                                             tileSizes, context);
-
-  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
-                                               tileSizes, context);
-
-  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
-                                               tileSizes, context);
-
-  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
-
-  populateVectorizationPatterns<ConvNDHWCOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
-
-  populateVectorizationPatterns<ConvNCDHWOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
-
-  patterns.push_back(std::move(tiling));
-  patterns.push_back(std::move(promotion));
-  patterns.push_back(std::move(vectorization));
-}

diff  --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
index 64d3405f016d..21aba6cdaf7b 100644
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse | FileCheck %s
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file | FileCheck %s
 
 // CHECK-DAG:  #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
 // CHECK-DAG:  #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
@@ -6,16 +7,11 @@
 // CHECK-DAG:  #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)>
 // CHECK-DAG:  #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
 
-func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
-  linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
-                outs(%arg2 : memref<?xf32>)
-  return
-}
-
 // CHECK-LABEL: @conv_1d
 //  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
 //  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
 //  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32
+func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
 //   CHECK-DAG:   %[[c12:.*]] = constant 12 : index
 //   CHECK-DAG:   %[[c4:.*]] = constant 4 : index
 //   CHECK-DAG:   %[[cst:.*]] = constant 0.000000e+00 : f32
@@ -50,3 +46,8 @@ func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>)
 //       CHECK:       scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
 //       CHECK:         %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref<?xf32>
 //       CHECK:         store %[[v23]], %[[v10]][%[[arg5]]] : memref<?xf32, #[[$map1]]>
+  linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+                outs(%arg2 : memref<?xf32>)
+  return
+}
+

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index aa249542a07d..3904353287c5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,8 +1,6 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
 
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// -----
 
 // CHECK-LABEL: contraction_dot
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -13,6 +11,8 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
   return
 }
 
+// -----
+
 // CHECK-LABEL: contraction_matvec
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
   // CHECK: vector.contract
@@ -22,6 +22,8 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
   return
 }
 
+// -----
+
 // CHECK-LABEL: contraction_matmul
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
   // CHECK: vector.contract
@@ -31,6 +33,8 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
   return
 }
 
+// -----
+
 // CHECK-LABEL: contraction_batch_matmul
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
   // CHECK: vector.contract
@@ -41,6 +45,8 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
   return
 }
 
+// -----
+
 #matmul_trait = {
   args_in = 2,
   args_out = 1,
@@ -51,8 +57,20 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
   ],
   iterator_types = ["parallel", "parallel", "reduction"]
 }
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @vectorization_test
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]]
+  //  CHECK-SAME:   vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
    outs(%C : memref<8x32xf32>) {
@@ -63,15 +81,33 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   }
   return
 }
-// CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
+// -----
+
+#matmul_trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @vectorization_test_integer
 func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                  %C: memref<8x32xi32>) {
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]],
+  //  CHECK-SAME:   vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
    outs(%C : memref<8x32xi32>) {
@@ -82,58 +118,71 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
   }
   return
 }
-// CHECK-LABEL: func @vectorization_test_integer
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
-//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
 
+// -----
+
+// CHECK-LABEL: func @vectorization_test_2
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
+  //       CHECK: vector.contract {{.*}} :
+  //  CHECK-SAME:   vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
    outs(%C: memref<8x32xf32>)
   return
 }
-// CHECK-LABEL: func @vectorization_test_2
-//       CHECK: vector.contract {{.*}} :
-//                vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill
 func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+  //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
   linalg.fill(%A, %arg0) :  memref<8x16xf32>, f32
   return
 }
-// CHECK-LABEL: func @test_vectorize_fill
-//       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
-//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
 
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill_scalar
 func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+  //  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+  //       CHECK:   store %[[V]], %[[M]][] : memref<f32>
   linalg.fill(%A, %arg0) :  memref<f32>, f32
   return
 }
-// CHECK-LABEL: func @test_vectorize_fill
-//  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
-//       CHECK:   store %[[V]], %[[M]][] : memref<f32>
 
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy
 func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  //       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
   linalg.copy(%A, %B) :  memref<8x16xf32>, memref<8x16xf32>
   return
 }
-// CHECK-LABEL: func @test_vectorize_copy
-//       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
-//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
 
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_scalar
 func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+  //       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+  //       CHECK: store %[[V]], {{.*}} : memref<f32>
   linalg.copy(%A, %B) :  memref<f32>, memref<f32>
   return
 }
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-//       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
-//       CHECK: store %[[V]], {{.*}} : memref<f32>
 
-func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
-  %arg2: memref<256xf32>, %i: f32) {
+// -----
+
+// CHECK-LABEL: func @generic_vectorize
+  //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
+  //  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
+func @generic_vectorize(%arg0: memref<4x256xf32>,
+                        %arg1: memref<4x256xf32>,
+                        %arg2: memref<256xf32>, %i: f32) {
+  //   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+  //   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
   %c1_f32 = constant 1.0 : f32
   linalg.generic {
     args_in = 0 : i64,
@@ -159,57 +208,56 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
     memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
     memref<4x256xf32>, memref<4x256xf32>) {
   ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+  //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+  //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+  //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
     %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
     %arg14 : f32):
+  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+  //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
     %6 = addf %arg4, %arg6 : f32
+  //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
     %7 = cmpf ogt, %arg3, %arg6 : f32
+  //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
     %8 = constant 2.0 : f32
+  //       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
     %9 = divf %arg5, %i : f32
+  //       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
     %10 = exp2 %arg5 : f32
+  //       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
     %11 = mulf %arg5, %8 : f32
+  //       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
     %12 = rsqrt %arg5 : f32
+  //       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
     %13 = select %7, %arg5, %arg6 : f32
+  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
     %14 = subf %arg5, %arg4 : f32
+  //       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
     %15 = tanh %arg5 : f32
+  //       CHECK:   vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+  //       CHECK:   vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
     linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
       f32, f32, f32, f32, f32, f32, f32, f32
   }
   return
 }
 
-// CHECK-LABEL: func @generic_vectorize
-//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
-//  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
-//   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
-//   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
-//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
-//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-//       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
-//       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
-//       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
-//       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
-//       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
-//       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
-//       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
-//       CHECK:   vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-//       CHECK:   vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 
+// -----
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
 func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
   %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
   %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
@@ -240,82 +288,105 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
   ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
     %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
     %arg14 : f32):
+  //   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+  //   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+  //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+  //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+  //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+  //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
     %6 = addf %arg4, %arg6 : f32
+  //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
     %7 = cmpf ogt, %arg3, %arg6 : f32
+  //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
     %8 = constant 2.0 : f32
+  //       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
     %9 = divf %arg5, %i : f32
+  //       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
     %10 = exp2 %arg5 : f32
+  //       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
     %11 = mulf %arg5, %8 : f32
+  //       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
     %12 = rsqrt %arg5 : f32
+  //       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
     %13 = select %7, %arg5, %arg6 : f32
+  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
     %14 = subf %arg5, %arg4 : f32
+  //       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
     %15 = tanh %arg5 : f32
+  //       CHECK:   %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
     linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
       f32, f32, f32, f32, f32, f32, f32, f32
   } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+  //       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
   return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
 }
 
-// CHECK-LABEL: func @generic_vectorize_tensor
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
-//  CHECK-SAME:  %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
-//   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
-//   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
-//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
-//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-//       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
-//       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
-//       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
-//       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
-//       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
-//       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
-//       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
-//       CHECK:   %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-//       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+// -----
 
+// CHECK-LABEL: func @matmul_tensors
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
 func @matmul_tensors(
   %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32> {
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
+  //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+  //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+  //
+  // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
+  // A later canonicalization fuses the add into vector.contract.
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+  //       CHECK:   %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
+  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                      outs(%arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32>
+  //       CHECK:   return %[[W]] : tensor<8x12xf32>
   return %0 : tensor<8x12xf32>
 }
 
-// CHECK-LABEL: func @matmul_tensors
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
-//  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
-//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-//   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
-//   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-//   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
-//   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
-//
-// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
-// a later canonicalization fuses the add into vector.contract.
-//       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
-//       CHECK:   %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
-//       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
-//       CHECK:   return %[[W]] : tensor<8x12xf32>
+// -----
+
+// CHECK-LABEL: func @matmul_i8_i8_i32
+//  CHECK-SAME:  %[[ARG0:[a-z0-9]+]]: memref<4x6xi8>
+//  CHECK-SAME:  %[[ARG1:[a-z0-9]+]]: memref<6x12xi8>
+//  CHECK-SAME:  %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
+func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+  //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
+  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
+  //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+  //
+  // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 (in i8), followed by
+  // sexti %tmp to i32 and addi %c, %tmp; the sexti and addi stay separate elementwise ops.
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]]
+  //  CHECK-SAME:     vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
+  //       CHECK:   %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
+  //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+  //       CHECK:   vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
+  //  CHECK-SAME:     vector<4x12xi32>, memref<4x12xi32>
+  linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)
+    outs(%c: memref<4x12xi32>)
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 126bbc3639af..27ca9942a74c 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -493,9 +493,11 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
 
 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   OwningRewritePatternList patterns;
+  // TODO: remove all this in favor of a single LinalgOp.
   patterns.insert<
       LinalgVectorizationPattern<BatchMatmulOp>,
       LinalgVectorizationPattern<MatmulOp>,
+      LinalgVectorizationPattern<MatmulI8I8I32Op>,
       LinalgVectorizationPattern<MatvecOp>,
       LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
       LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,


        


More information about the Mlir-commits mailing list