[Mlir-commits] [mlir] 17b2e73 - [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 23 14:29:31 PST 2023

Author: Quentin Colombet
Date: 2023-02-23T22:27:16Z
New Revision: 17b2e73cb477d42771fbc68a215dde648f3eaaef

URL: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef
DIFF: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef.diff

LOG: [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col

This patch adds patterns to convert `linalg.conv_2d_xxx` operations
into `linalg.generic` (for img2col packing) and `linalg.matmul`.

The meat of the patch comes straight from IREE.
(To the original authors: are you okay with that?)

What this patch adds is proper plumbing of the im2col patterns into the
transform dialect.

PS: Feel free to add more reviewers. I wanted to cover the original contributors of im2col in IREE but I'm not sure I got all of them.

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144108
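
To make the conv2d-to-im2col idea in the log above concrete, here is a minimal
standalone C++ sketch. It is not code from the patch: the names `Matrix`,
`im2col`, `flatten`, `convViaIm2col`, and `convDirect` are hypothetical, and it
assumes a single-channel input, unit strides, no padding, and no dilation. It
packs every filter-sized window of the image into one row of a matrix and then
computes the convolution as a matrix-vector product with the flattened filter.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Hypothetical helper: builds the im2col matrix with one row per output
// position and one column per filter tap (row-major over the window).
Matrix im2col(const Matrix &x, int kh, int kw) {
  const int oh = static_cast<int>(x.size()) - kh + 1;
  const int ow = static_cast<int>(x[0].size()) - kw + 1;
  Matrix cols;
  for (int i = 0; i < oh; ++i)
    for (int j = 0; j < ow; ++j) {
      std::vector<int> row;
      for (int a = 0; a < kh; ++a)
        for (int b = 0; b < kw; ++b)
          row.push_back(x[i + a][j + b]);
      cols.push_back(row);
    }
  return cols;
}

// Flattens the filter in the same row-major order used for the windows.
std::vector<int> flatten(const Matrix &w) {
  std::vector<int> flat;
  for (const auto &row : w)
    flat.insert(flat.end(), row.begin(), row.end());
  return flat;
}

// Convolution expressed as (im2col matrix) x (flattened filter).
std::vector<int> convViaIm2col(const Matrix &x, const Matrix &w) {
  const Matrix cols = im2col(x, static_cast<int>(w.size()),
                             static_cast<int>(w[0].size()));
  const std::vector<int> flat = flatten(w);
  std::vector<int> y;
  for (const auto &row : cols) {
    assert(row.size() == flat.size());
    int acc = 0;
    for (std::size_t k = 0; k < flat.size(); ++k)
      acc += row[k] * flat[k];
    y.push_back(acc);
  }
  return y;
}

// Reference: direct sliding-window cross-correlation.
std::vector<int> convDirect(const Matrix &x, const Matrix &w) {
  const std::size_t kh = w.size(), kw = w[0].size();
  std::vector<int> y;
  for (std::size_t i = 0; i + kh <= x.size(); ++i)
    for (std::size_t j = 0; j + kw <= x[0].size(); ++j) {
      int acc = 0;
      for (std::size_t a = 0; a < kh; ++a)
        for (std::size_t b = 0; b < kw; ++b)
          acc += x[i + a][j + b] * w[a][b];
      y.push_back(acc);
    }
  return y;
}
```

For a 3x3 input and a 2x2 filter, `im2col` yields a 4x4 matrix (four output
positions, four filter taps), and `convViaIm2col` returns the same values as
`convDirect`. The patch applies the same idea to tensors, with extra batch and
channel dimensions.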




diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 11f3b3c634fdf..c53497839ea9a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1740,6 +1740,82 @@ def HoistRedundantVectorTransfersOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::func::FuncOp target,
+         ::mlir::transform::ApplyToEachResultList &results,
+         ::mlir::transform::TransformState &state);
+   }];
+// ConvertConv2DToImg2ColOp
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+    "structured.convert_conv2d_to_img2col",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
+  let description = [{
+    Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
+    and linalg.matmul.
+    A convolution operation can be written as a matrix-matrix multiplication by
+    unfolding the cross-correlation between input and filter and explicitly
+    copying the overlapping sliding-window inputs.
+    Consider a 2D input X with a single input and output channel and a 2x2
+    filter W:
+    ```
+    [x(0, 0)  , x(0, 1)  , ..., x(0, n-1)  ]
+    [x(1, 0)  , x(1, 1)  , ..., x(1, n-1)  ]
+    [.        ,  .       ,.   ,      .     ]            [w(0, 0), w(0, 1)]
+    [.        ,  .       , .  ,      .     ]    (conv)  [w(1, 0), w(1, 1)]
+    [.        ,  .       ,   .,      .     ]
+    [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+    ```
+    The packed input data (img2col) is a matrix with |rows| = output spatial
+    size, |columns| = filter spatial size. To compute the output Y(i, j) we need
+    to calculate the dot product between the filter window at input X(x, y) and
+    the filter, which looks like the following, where the l.h.s. is the img2col
+    matrix and the r.h.s. is the flattened filter:
+    ```
+    [x(0,0), x(0,1), x(1,0), x(1,1)]
+    [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
+    [x(0,2), x(0,3), x(1,2), x(1,3)]
+    [   .  ,    .  ,    .  ,    .  ]
+    ```
+    In general, for the 2D case with (N, H, W, C) input, (Kh, Kw, C, D) filter,
+    and (N, Ho, Wo, D) output, the convolution is the following matrix-matrix
+    multiplication: (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in
+    the N inputs. For the case where N > 1 it is a batched matrix-matrix
+    multiplication.
+    Returns two handles:
+    - One on the operation that produces the img2col tensor.
+    - One on the final operation of the sequence that replaces the original
+      convolution.
+    #### Return modes:
+    Returns a definite failure if target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$img2col_tensor,
+                      TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c782cab1e9f94..8e77b54251aa7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -879,6 +879,64 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
+/// Populates patterns to transform linalg.conv_2d_xxx operations into
+/// linalg.generic (for img2col packing) and linalg.matmul.
+/// \see rewriteInIm2Col for more details.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+/// and linalg.matmul.
+/// A convolution operation can be written as a matrix-matrix multiplication by
+/// unfolding the cross-correlation between input and filter and explicitly
+/// copying the overlapping sliding-window inputs.
+/// Consider a 2D input X with a single input and output channel and a 2x2
+/// filter W:
+/// [x(0, 0)  , x(0, 1)  , ..., x(0, n-1)  ]
+/// [x(1, 0)  , x(1, 1)  , ..., x(1, n-1)  ]
+/// [.        ,  .       ,.   ,      .     ]            [w(0, 0), w(0, 1)]
+/// [.        ,  .       , .  ,      .     ]    (conv)  [w(1, 0), w(1, 1)]
+/// [.        ,  .       ,   .,      .     ]
+/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+/// The packed input data (img2col) is a matrix with |rows| = output spatial
+/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need
+/// to calculate the dot product between the filter window at input X(x, y) and
+/// the filter, which looks like the following, where the l.h.s. is the img2col
+/// matrix and the r.h.s. is the flattened filter:
+/// [x(0,0), x(0,1), x(1,0), x(1,1)]
+/// [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
+/// [x(0,2), x(0,3), x(1,2), x(1,3)]
+/// [   .  ,    .  ,    .  ,    .  ]
+/// In general, for the 2D case with (N, H, W, C) input, (Kh, Kw, C, D) filter,
+/// and (N, Ho, Wo, D) output, the convolution is the following matrix-matrix
+/// multiplication: (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in
+/// the N inputs. For the case where N > 1 it is a batched matrix-matrix
+/// multiplication.
+/// On success, return both the operation that produces the img2col tensor and
+/// the final operation of the sequence that replaces the original convolution.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except there is no
+/// reduction among the input channels, so each convolution can be a
+/// matrix-vector product. By transposing both input and filter so that the
+/// channels are outermost, the computation is a batched matrix-vector product.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp);
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the
+/// channels are to the left of the image shape dimensions, the position of the
+/// contraction dimension in the resulting matmul is reversed. This swaps the
+/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) *
+/// (C x Kh x Kw, Ho x Wo))
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
 // Op-specific patterns.
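
The shape bookkeeping described in the header comments above can be checked
with a small standalone C++ sketch. It is not code from the patch: `ConvShape`,
`directConv`, `im2col`, `batchedMatmulSharedRhs`, and
`im2colAgreesWithDirectConv` are hypothetical names, and it assumes row-major
storage, no padding, and no dilation. It builds the (N, Ho, Wo, Kh, Kw, C)
column tensor using the same index expression as the rewrite
(`oh * sh + kh`, `ow * sw + kw`), then reads it as (N, Ho x Wo, Kh x Kw x C)
and multiplies it by the filter read as (Kh x Kw x C, D).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical shape record for an NHWC input and an HWCF filter.
struct ConvShape {
  int n, h, w, c; // input  (N, H, W, C)
  int kh, kw, d;  // filter (Kh, Kw, C, D)
  int sh, sw;     // strides along H and W
  int oh() const { return (h - kh) / sh + 1; }
  int ow() const { return (w - kw) / sw + 1; }
};

// Reference: direct NHWC x HWCF convolution producing (N, Ho, Wo, D).
std::vector<int> directConv(const std::vector<int> &in,
                            const std::vector<int> &filter,
                            const ConvShape &s) {
  const int OH = s.oh(), OW = s.ow();
  std::vector<int> out(static_cast<std::size_t>(s.n) * OH * OW * s.d, 0);
  for (int n = 0; n < s.n; ++n)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int d = 0; d < s.d; ++d)
          for (int kh = 0; kh < s.kh; ++kh)
            for (int kw = 0; kw < s.kw; ++kw)
              for (int c = 0; c < s.c; ++c) {
                int x = in[((n * s.h + oh * s.sh + kh) * s.w + ow * s.sw + kw) *
                               s.c + c];
                int f = filter[((kh * s.kw + kw) * s.c + c) * s.d + d];
                out[((n * OH + oh) * OW + ow) * s.d + d] += x * f;
              }
  return out;
}

// Column tensor of shape (N, Ho, Wo, Kh, Kw, C), stored row-major. The same
// buffer can be read as (N, Ho x Wo, Kh x Kw x C) without moving any data.
std::vector<int> im2col(const std::vector<int> &in, const ConvShape &s) {
  std::vector<int> col;
  for (int n = 0; n < s.n; ++n)
    for (int oh = 0; oh < s.oh(); ++oh)
      for (int ow = 0; ow < s.ow(); ++ow)
        for (int kh = 0; kh < s.kh; ++kh)
          for (int kw = 0; kw < s.kw; ++kw)
            for (int c = 0; c < s.c; ++c)
              col.push_back(
                  in[((n * s.h + oh * s.sh + kh) * s.w + ow * s.sw + kw) * s.c +
                     c]);
  return col;
}

// (B, M, K) * (K, N) -> (B, M, N): the right-hand side has no batch dimension.
std::vector<int> batchedMatmulSharedRhs(const std::vector<int> &lhs,
                                        const std::vector<int> &rhs, int B,
                                        int M, int K, int N) {
  std::vector<int> out(static_cast<std::size_t>(B) * M * N, 0);
  for (int b = 0; b < B; ++b)
    for (int m = 0; m < M; ++m)
      for (int nn = 0; nn < N; ++nn)
        for (int k = 0; k < K; ++k)
          out[(b * M + m) * N + nn] +=
              lhs[(b * M + m) * K + k] * rhs[k * N + nn];
  return out;
}

// Compares both formulations on deterministic data with non-unit stride.
bool im2colAgreesWithDirectConv() {
  const ConvShape s{2, 5, 5, 3, 3, 2, 4, 2, 1};
  std::vector<int> in(static_cast<std::size_t>(s.n) * s.h * s.w * s.c);
  std::vector<int> filter(static_cast<std::size_t>(s.kh) * s.kw * s.c * s.d);
  for (std::size_t i = 0; i < in.size(); ++i)
    in[i] = static_cast<int>(i * 7 % 11) - 5;
  for (std::size_t i = 0; i < filter.size(); ++i)
    filter[i] = static_cast<int>(i * 3 % 5) - 2;
  const int M = s.oh() * s.ow();
  const int K = s.kh * s.kw * s.c;
  return batchedMatmulSharedRhs(im2col(in, s), filter, s.n, M, K, s.d) ==
         directConv(in, filter, s);
}
```

Because the buffers are row-major, collapsing (Ho, Wo) and (Kh, Kw, C) is only
a reinterpretation of the same data, which is why the sketch never reshapes
anything explicitly. The filter is shared across the batch, so the product is
not a plain batched matmul with a batch dimension on both operands.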

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index aae073c28db8d..2d102383ddbe0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -3072,6 +3073,40 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
+// ConvertConv2DToImg2ColOp.
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
+          target)
+          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Case([&](linalg::Conv2DNchwFchwOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+  // Handle to the operation producing the img2col tensor.
+  results.push_back(maybeTransformed->first);
+  // Handle to the operation that replaces the original convolution.
+  results.push_back(maybeTransformed->second);
+  return DiagnosedSilenceableFailure::success();
+}
+
 // Transform op registration

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7ede7e4f96915..adcc87f42dab2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
+  ConvertConv2DToImg2Col.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
new file mode 100644
index 0000000000000..14bff411ef8c1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -0,0 +1,540 @@
+//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <utility>
+namespace mlir {
+namespace linalg {
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](APInt element) { return element.getSExtValue() == 1; });
+}
+
+static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
+  bool isInt = x.getType().isa<IntegerType>();
+  if (isInt)
+    return builder.create<arith::AddIOp>(loc, x, y);
+  return builder.create<arith::AddFOp>(loc, x, y);
+}
+
+static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
+  bool isInt = x.getType().isa<IntegerType>();
+  if (isInt)
+    return builder.create<arith::MulIOp>(loc, x, y);
+  return builder.create<arith::MulFOp>(loc, x, y);
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int n = outputShape[0];
+  int oh = outputShape[1];
+  int ow = outputShape[2];
+  int oc = outputShape[3];
+  int fh = filterShape[0];
+  int fw = filterShape[1];
+  int ic = filterShape[2];
+  Location loc = convOp.getLoc();
+  SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+  AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
+  bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+  SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim, icDim};
+  auto nloops = colTensorShape.size();
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0, 1, 2}, {3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0}, {1, 2}, {3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  }
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. (B x) M x K * K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(loc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+  Location loc = convOp.getLoc();
+  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
+    auto operandTensorType = operand.getType().cast<RankedTensorType>();
+    auto nloops = indices.size();
+    ArrayRef<int64_t> inputShape = operandTensorType.getShape();
+    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
+        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+          return rewriter.getAffineDimExpr(index);
+        }));
+    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
+        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
+    Value outputTensor = rewriter.create<tensor::EmptyOp>(
+        loc, targetShape, operandTensorType.getElementType());
+    SmallVector<utils::IteratorType> loopAttributeTypes(
+        nloops, utils::IteratorType::parallel);
+    SmallVector<AffineMap> indexingMaps = {
+        inversePermutation(
+            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
+        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+    auto transposedOp = rewriter.create<linalg::GenericOp>(
+        loc, outputTensor.getType(),
+        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+        loopAttributeTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+    return transposedOp.getResult(0);
+  };
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  // Transpose input, filter so channels are outermost
+  Value inputT = transposeOperand(input, {0, 3, 1, 2});
+  Value filterT = transposeOperand(filter, {2, 0, 1});
+  ArrayRef<int64_t> filterTShape =
+      filterT.getType().cast<RankedTensorType>().getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int n = outputShape[0];
+  int oh = outputShape[1];
+  int ow = outputShape[2];
+  int c = outputShape[3];
+  int fh = filterTShape[1];
+  int fw = filterTShape[2];
+  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
+  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
+  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
+  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim};
+  auto nloops = colTensorShape.size();
+  SmallVector<utils::IteratorType> loopAttributeTypes(
+      nloops, utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
+      loopAttributeTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
+      {0, 1}, {2, 3}, {4, 5}};
+  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
+  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
+                                                                 {2, 3}};
+  auto reshapedImg2ColTensorType = RankedTensorType::get(
+      {n * c, oh * ow, fh * fw}, inputType.getElementType());
+  auto reshapedFilterTensorType =
+      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
+  auto reshapedOutputTensorType =
+      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
+  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputTensorType, transposedOutputTensor,
+      outputReassociationIndice);
+  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
+      loc, TypeRange{reshapedoutputTensor.getType()},
+      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+      ValueRange{reshapedoutputTensor});
+  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
+                                                                      {2, 3}};
+  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+      loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
+      batchMatVecReassociationIndice);
+  Value transposedResult =
+      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
+  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        transposedResult.getDefiningOp());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto filterShape = filterType.getShape();
+  auto outputShape = outputType.getShape();
+  int n = outputShape[0];
+  int oc = outputShape[1];
+  int oh = outputShape[2];
+  int ow = outputShape[3];
+  int ic = filterShape[1];
+  int fh = filterShape[2];
+  int fw = filterShape[3];
+  auto loc = convOp.getLoc();
+  SmallVector<int64_t, 4> colTensorShape = {n, ic, fh, fw, oh, ow};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+  MLIRContext *context = rewriter.getContext();
+  AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
+  bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+  auto shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  auto swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+  SmallVector<AffineExpr, 4> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
+                                           owDim * swSym + kwDim};
+  auto nloops = colTensorShape.size();
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
+  SmallVector<AffineMap, 4> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0, 1}, {2, 3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0}, {1}, {2, 3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+  }
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. M x K * (B x) K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(loc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+namespace {
+class ConvertConv2DNhwcHwcf final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+class ConvertDepthwiseConv2DNhwcHwc final
+    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
+  using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
+class ConvertConv2DNchwFchw final
+    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
+} // end anonymous namespace
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+                  ConvertConv2DNchwFchw>(context);
+}
+} // end namespace linalg
+} // end namespace mlir

diff  --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
new file mode 100644
index 0000000000000..e33e51ddababb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+// Check that the im2col patterns are properly connected with the
+// transform dialect.
+// Non static shapes are not supported.
+// Check that we emit an error.
+// TODO: Hook up the rewriter errors in transform dialect.
+func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+    // expected-note@below {{when applied to this op}}
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+    return %0 : tensor<?x?x?x?xf32>
+}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-error@below {{failed to apply}}
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+// -----
+// Check that we get the proper handles for the img2col tensor producer
+// and the final instruction.
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//      CHECK: @conv_16433136
+//      CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//      CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+//      CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP0]]
+//           CHECK-SAME: #[[MAP1]]
+//                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %[[IN_DATA]] : f32
+//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+//           CHECK-SAME: [0, 1, 2], [3, 4, 5]
+//           CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+//           CHECK-SAME: [0, 1, 2], [3]
+//           CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
+//           CHECK-SAME: [0, 1, 2], [3]
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: return %[[RESULT]]
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation
+  transform.print %transformed {name = "transformed"}: !pdl.operation
+}
+// -----
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+// CHECK: @depthwise_conv_hwc_114x16x3
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
+//      CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
+//      CHECK: %[[INPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
+// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
+// CHECK-NEXT:     linalg.yield %[[ARG3]] : f32
+// CHECK-NEXT:  } -> tensor<1x16x114x114xf32>
+//      CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
+//      CHECK: %[[FILTER_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+//      CHECK:      linalg.yield
+//      CHECK:    } -> tensor<16x3x3xf32>
+//      CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
+//      CHECK: %[[OUTPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
+// CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:     linalg.yield
+// CHECK-NEXT:  } -> tensor<1x16x112x112xf32>
+//      CHECK:  %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:         linalg.yield
+// CHECK-NEXT:    } -> tensor<1x16x112x112x3x3xf32>
+//      CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME:    tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
+//      CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
+// CHECK-SAME:    tensor<16x3x3xf32> into tensor<16x9xf32>
+//      CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
+// CHECK-SAME:    tensor<1x16x112x112xf32> into tensor<16x12544xf32>
+//      CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
+//      CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
+// CHECK-SAME:    tensor<16x12544xf32> into tensor<1x16x112x112xf32>
+//      CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
+//      CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:      linalg.yield
+// CHECK-NEXT:    } -> tensor<1x112x112x16xf32>
+//      CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
+func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
+    %0 = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+    return %0 : tensor<1x112x112x16xf32>
+}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+// -----
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: func.func @batch_nhwc_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
+//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
+//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+//      CHECK:   %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:   ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
+//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//      CHECK:     linalg.yield %[[ADD]] : f32
+//      CHECK:   } -> tensor<8x196x16xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+//      CHECK:   return %[[CS_FINAL]]
+func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
+    return %0 : tensor<8x14x14x16xf32>
+}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+// -----
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: func.func @batch_nchw_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
+//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
+//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+//      CHECK:   %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:   ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
+//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//      CHECK:     linalg.yield %[[ADD]] : f32
+//      CHECK:   } -> tensor<8x16x196xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+//      CHECK:   return %[[CS_FINAL]]
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+    %0 = linalg.conv_2d_nchw_fchw
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+      outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+    return %0 : tensor<8x16x14x14xf32>
+}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}

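For readers unfamiliar with the rewrite exercised by the tests above, here is a
minimal standalone sketch of the same idea on plain arrays. It is not code from
the patch: ConvShape, convDirect, im2col and convViaIm2Col are hypothetical
names, and it assumes the NHWC/HWCF layout with unit strides and dilations and
no padding. The batch dimension is folded into the matmul rows here, which
matches the N = 1 test; the patch keeps a separate batch dimension for N > 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: hypothetical names, not part of the patch.
struct ConvShape {
  std::size_t n, h, w, c; // input:  N x H x W x C
  std::size_t kh, kw, f;  // filter: KH x KW x C x F
  std::size_t oh() const { return h - kh + 1; }
  std::size_t ow() const { return w - kw + 1; }
  std::size_t m() const { return n * oh() * ow(); } // matmul rows
  std::size_t k() const { return kh * kw * c; }     // reduction size
};

// Reference: out[n][y][x][f] += in[n][y+i][x+j][c] * filter[i][j][c][f].
std::vector<float> convDirect(const ConvShape &s, const std::vector<float> &in,
                              const std::vector<float> &filter) {
  std::vector<float> out(s.m() * s.f, 0.0f);
  for (std::size_t n = 0; n < s.n; ++n)
    for (std::size_t y = 0; y < s.oh(); ++y)
      for (std::size_t x = 0; x < s.ow(); ++x)
        for (std::size_t f = 0; f < s.f; ++f)
          for (std::size_t i = 0; i < s.kh; ++i)
            for (std::size_t j = 0; j < s.kw; ++j)
              for (std::size_t c = 0; c < s.c; ++c)
                out[((n * s.oh() + y) * s.ow() + x) * s.f + f] +=
                    in[((n * s.h + y + i) * s.w + x + j) * s.c + c] *
                    filter[((i * s.kw + j) * s.c + c) * s.f + f];
  return out;
}

// Packing step: col[n][y][x][i][j][c] = in[n][y+i][x+j][c]. Stored row-major,
// the result can be read as an (N*OH*OW) x (KH*KW*C) matrix as-is.
std::vector<float> im2col(const ConvShape &s, const std::vector<float> &in) {
  assert(in.size() == s.n * s.h * s.w * s.c);
  std::vector<float> col(s.m() * s.k());
  std::size_t idx = 0;
  for (std::size_t n = 0; n < s.n; ++n)
    for (std::size_t y = 0; y < s.oh(); ++y)
      for (std::size_t x = 0; x < s.ow(); ++x)
        for (std::size_t i = 0; i < s.kh; ++i)
          for (std::size_t j = 0; j < s.kw; ++j)
            for (std::size_t c = 0; c < s.c; ++c)
              col[idx++] = in[((n * s.h + y + i) * s.w + x + j) * s.c + c];
  return col;
}

// Convolution as a plain matmul: (M x K) * (K x F) -> (M x F). The HWCF
// filter is already a row-major K x F matrix, so it needs no repacking.
std::vector<float> convViaIm2Col(const ConvShape &s,
                                 const std::vector<float> &in,
                                 const std::vector<float> &filter) {
  const std::vector<float> col = im2col(s, in);
  std::vector<float> out(s.m() * s.f, 0.0f);
  for (std::size_t row = 0; row < s.m(); ++row)
    for (std::size_t f = 0; f < s.f; ++f)
      for (std::size_t r = 0; r < s.k(); ++r)
        out[row * s.f + f] += col[row * s.k() + r] * filter[r * s.f + f];
  return out;
}
```

With the shapes from @conv_16433136 (input 1x16x16x4, filter 3x3x4x16), m() is
196 and k() is 36, i.e. the 196x36 by 36x16 matmul that the CHECK lines expect.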