[Mlir-commits] [mlir] 17b2e73 - [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 23 14:29:31 PST 2023


Author: Quentin Colombet
Date: 2023-02-23T22:27:16Z
New Revision: 17b2e73cb477d42771fbc68a215dde648f3eaaef

URL: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef
DIFF: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef.diff

LOG: [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col

This patch adds patterns to convert `linalg.conv_2d_xxx` operations
into `linalg.generic` (for img2col packing) and `linalg.matmul`.

The meat of the patch comes straight from IREE
(https://github.com/iree-org/iree).
(To the original authors: are you okay with that?)

What this patch adds is proper plumbing of the im2col patterns into the
transform dialect.

PS: Feel free to add more reviewers. I wanted to cover the original contributors of im2col in IREE but I'm not sure I got all of them.

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144108

Added: 
    mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
    mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 11f3b3c634fdf..c53497839ea9a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1740,6 +1740,82 @@ def HoistRedundantVectorTransfersOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::func::FuncOp target,
+         ::mlir::transform::ApplyToEachResultList &results,
+         ::mlir::transform::TransformState &state);
+   }];
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp
+//===----------------------------------------------------------------------===//
+
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+    "structured.convert_conv2d_to_img2col",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
+  let description = [{
+    Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
+    and linalg.matmul.
+
+    The supported convolutions are linalg.conv_2d_nhwc_hwcf,
+    linalg.depthwise_conv_2d_nhwc_hwc and linalg.conv_2d_nchw_fchw.
+
+    A convolution operation can be written as a matrix-matrix multiplication by
+    unfolding the cross-correlation between input and filter and explicitly
+    copying overlapped sliding window inputs.
+
+    Consider 2D input X with single channel input and output and 2x2 filter W:
+    ```
+    [x(0, 0)  , x(0, 1)  , ...,  x(0, n-1) ]
+    [x(1, 0)  , x(1, 1)  , ...,  x(1, n-1) ]
+    [.        ,  .       ,.   ,      .     ]            [w(0, 0), w(0, 1)]
+    [.        ,  .       , .  ,      .     ]    (conv)  [w(1, 0), w(1, 1)]
+    [.        ,  .       ,   .,      .     ]
+    [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+    ```
+
+    The packed input data (img2col) is a matrix with |rows| = output spatial
+    size, |columns| = filter spatial size. To compute the output Y(i, j) we need
+    to calculate the dot product between the filter window at input X(i, j) and
+    the filter. This looks like the following, where the l.h.s. is the img2col
+    matrix and the r.h.s. is the flattened filter:
+    ```
+    [x(0,0), x(0,1), x(1,0), x(1,1)]
+    [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
+    [x(0,2), x(0,3), x(1,2), x(1,3)]
+    [   .  ,    .  ,    .  ,    .  ]
+    ```
+
+    In general, for the 2D case with (N, H, W, C) input, (Kh, Kw, C, D) filter
+    and (N, Ho, Wo, D) output, the convolution is the matrix-matrix
+    multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
+    inputs. When N > 1 it is a batched matrix-matrix multiplication.
+
+    Returns two handles:
+    - One on the operation that produces the img2col tensor.
+    - One on the final operation of the sequence that replaces the original
+      convolution.
+
+    #### Return modes:
+
+    Returns a silenceable failure if the target is not one of the supported
+    convolutions or if the rewrite cannot be applied to it (for instance
+    because of non-static shapes or non-unit dilations).
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$img2col_tensor,
+                      TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c782cab1e9f94..8e77b54251aa7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -879,6 +879,64 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns to transform linalg.conv_2d_xxx operations into
+/// linalg.generic (for img2col packing) and linalg.matmul.
+/// The populated patterns cover linalg.conv_2d_nhwc_hwcf,
+/// linalg.depthwise_conv_2d_nhwc_hwc and linalg.conv_2d_nchw_fchw.
+/// \see rewriteInIm2Col for more details.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+
+/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+/// and linalg.matmul.
+///
+/// A convolution operation can be written as a matrix-matrix multiplication by
+/// unfolding the cross-correlation between input and filter and explicitly
+/// copying overlapped sliding window inputs.
+///
+/// Consider 2D input X with single channel input and output and 2x2 filter W:
+/// [x(0, 0)  , x(0, 1)  , ...,  x(0, n-1) ]
+/// [x(1, 0)  , x(1, 1)  , ...,  x(1, n-1) ]
+/// [.        ,  .       ,.   ,      .     ]            [w(0, 0), w(0, 1)]
+/// [.        ,  .       , .  ,      .     ]    (conv)  [w(1, 0), w(1, 1)]
+/// [.        ,  .       ,   .,      .     ]
+/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+///
+/// The packed input data (img2col) is a matrix with |rows| = output spatial
+/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need
+/// to calculate the dot product between the filter window at input X(i, j) and
+/// the filter. This looks like the following, where the l.h.s. is the img2col
+/// matrix and the r.h.s. is the flattened filter:
+///
+/// [x(0,0), x(0,1), x(1,0), x(1,1)]
+/// [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
+/// [x(0,2), x(0,3), x(1,2), x(1,3)]
+/// [   .  ,    .  ,    .  ,    .  ]
+///
+/// In general, for the 2D case with (N, H, W, C) input, (Kh, Kw, C, D) filter
+/// and (N, Ho, Wo, D) output, the convolution is the matrix-matrix
+/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
+/// inputs. When N > 1 it is a batched matrix-matrix multiplication, emitted as
+/// a linalg.generic because the filter has no batch dimension; when N == 1 a
+/// linalg.matmul is emitted.
+///
+/// Returns failure, without modifying the IR, when the op cannot be handled,
+/// e.g. when the input or filter shape is not static or when the dilations
+/// are not all ones.
+///
+/// On success, return both the operation that produces the img2col tensor and
+/// the final operation of the sequence that replaces the original convolution.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except that there
+/// is no reduction across input channels, so each convolution can be a
+/// matrix-vector product. By transposing both the input and the filter so that
+/// channels are outermost, the computation becomes a batched matrix-vector
+/// product.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except that,
+/// because the channels are to the left of the image shape dimensions, the
+/// position of the contraction dimension in the resulting matmul is reversed.
+/// This swaps the LHS and RHS of the matmul when compared with nhwc, i.e.
+/// (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo).
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index aae073c28db8d..2d102383ddbe0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -3072,6 +3073,40 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
   results.push_back(target);
   return DiagnosedSilenceableFailure::success();
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp.
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a single supported 2-D convolution into its img2col form and
+/// returns two handles: the img2col tensor producer and the op that replaces
+/// the convolution. Unsupported ops and failed rewrites produce the default
+/// silenceable failure.
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // New IR is materialized right in front of the matched convolution.
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+
+  // Dispatch on the concrete convolution kind; each supported kind has its own
+  // rewriteInIm2Col overload.
+  Operation *convOp = target.getOperation();
+  FailureOr<std::pair<Operation *, Operation *>> rewritten = failure();
+  if (auto nhwcConv = dyn_cast<linalg::Conv2DNhwcHwcfOp>(convOp))
+    rewritten = rewriteInIm2Col(rewriter, nhwcConv);
+  else if (auto depthwiseConv =
+               dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(convOp))
+    rewritten = rewriteInIm2Col(rewriter, depthwiseConv);
+  else if (auto nchwConv = dyn_cast<linalg::Conv2DNchwFchwOp>(convOp))
+    rewritten = rewriteInIm2Col(rewriter, nchwConv);
+  else
+    rewritten = rewriter.notifyMatchFailure(convOp, "not supported");
+
+  if (failed(rewritten))
+    return emitDefaultSilenceableFailure(target);
+
+  // Result order matches the op definition: img2col_tensor, then transformed.
+  auto [img2colProducer, replacement] = *rewritten;
+  results.push_back(img2colProducer);
+  results.push_back(replacement);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7ede7e4f96915..adcc87f42dab2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   ConstantFold.cpp
   ConvertToDestinationStyle.cpp
+  ConvertConv2DToImg2Col.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
new file mode 100644
index 0000000000000..14bff411ef8c1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -0,0 +1,540 @@
+//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <utility>
+
+namespace mlir {
+namespace linalg {
+/// Returns true when no element of `attr` differs from one (vacuously true for
+/// an empty attribute). Elements are compared as sign-extended integers.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  auto isNotOne = [](APInt value) { return value.getSExtValue() != 1; };
+  return llvm::none_of(attr, isNotOne);
+}
+
+/// Builds `x + y` at `loc`: arith.addi when `x` is an IntegerType, arith.addf
+/// otherwise.
+static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
+  if (!x.getType().isa<IntegerType>())
+    return builder.create<arith::AddFOp>(loc, x, y);
+  return builder.create<arith::AddIOp>(loc, x, y);
+}
+
+/// Builds `x * y` at `loc`: arith.muli when `x` is an IntegerType, arith.mulf
+/// otherwise.
+static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
+  if (!x.getType().isa<IntegerType>())
+    return builder.create<arith::MulFOp>(loc, x, y);
+  return builder.create<arith::MulIOp>(loc, x, y);
+}
+
+/// Rewrites linalg.conv_2d_nhwc_hwcf as an img2col packing linalg.generic
+/// followed by a linalg.matmul (batch == 1) or a batched matmul expressed as a
+/// linalg.generic (batch > 1). All match failures are reported before any IR
+/// is created.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
+  // The rewrite creates tensor.empty and tensor reshape ops, so it is only
+  // valid on tensors.
+  if (!convOp.hasTensorSemantics())
+    return rewriter.notifyMatchFailure(convOp, "expected tensor semantics");
+
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // The output sizes are baked into static reshape and result types below.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  // Keep sizes in 64 bits: products such as oh * ow or fh * fw * ic become
+  // tensor dimensions and could overflow a 32-bit int.
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[0];
+  int64_t fw = filterShape[1];
+  int64_t ic = filterShape[2];
+
+  // The batched (n != 1) path emits a hand-written multiply-accumulate body
+  // that inserts no casts, so all operands must share one element type there.
+  Type outputElementType = outputType.getElementType();
+  if (n != 1 && (inputType.getElementType() != outputElementType ||
+                 filterType.getElementType() != outputElementType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected matching element types for batched convolution");
+
+  Location loc = convOp.getLoc();
+
+  // img2col tensor: for every output pixel (n, oh, ow), the (fh, fw, ic)
+  // window of input values that contributes to it.
+  SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
+  bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
+
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  // col[n, oh, ow, kh, kw, ic] = input[n, oh * sh + kh, ow * sw + kw, ic].
+  SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim, icDim};
+
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  // Collapse to matrices: img2col [(n x) oh*ow, fh*fw*ic] and output
+  // [(n x) oh*ow, oc]; the batch dimension is only kept when n != 1.
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0, 1, 2}, {3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0}, {1, 2}, {3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  }
+
+  // The collapsed filter keeps the filter's own element type.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. (B x) M x K * K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(nestedLoc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+/// Rewrites linalg.depthwise_conv_2d_nhwc_hwc by moving channels outermost,
+/// packing the input with img2col and computing one matrix-vector product per
+/// (batch, channel) pair. For batch == 1 this is a linalg.batch_matvec over
+/// the channels. For batch > 1 the filter is only indexed by the channel, so
+/// the batched matvec is expressed as a linalg.generic instead. All match
+/// failures are reported before any IR is created.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
+  // Check for tensors first: the RankedTensorType casts below assert on
+  // memref operands, and the rewrite creates tensor ops anyway.
+  if (!convOp.hasTensorSemantics())
+    return rewriter.notifyMatchFailure(convOp, "expected tensor semantics");
+
+  auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // The output sizes are baked into static tensor.empty/reshape types below.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  // Output is (n, oh, ow, c) and filter is (fh, fw, c). Keep sizes in 64 bits
+  // since products such as oh * ow become tensor dimensions.
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t c = outputShape[3];
+  int64_t fh = filterShape[0];
+  int64_t fw = filterShape[1];
+
+  // The batched (n != 1) path emits a hand-written multiply-accumulate body
+  // that inserts no casts, so all operands must share one element type there.
+  Type outputElementType = outputType.getElementType();
+  if (n != 1 && (inputType.getElementType() != outputElementType ||
+                 filterType.getElementType() != outputElementType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected matching element types for batched convolution");
+
+  Location loc = convOp.getLoc();
+  MLIRContext *context = rewriter.getContext();
+
+  // Returns `operand` permuted so that result dimension i is operand
+  // dimension indices[i], materialized as a copying linalg.generic.
+  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
+    auto operandTensorType = operand.getType().cast<RankedTensorType>();
+    auto nloops = indices.size();
+    ArrayRef<int64_t> inputShape = operandTensorType.getShape();
+
+    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
+        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+          return rewriter.getAffineDimExpr(index);
+        }));
+
+    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
+        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
+
+    Value outputTensor = rewriter.create<tensor::EmptyOp>(
+        loc, targetShape, operandTensorType.getElementType());
+
+    SmallVector<utils::IteratorType> loopAttributeTypes(
+        nloops, utils::IteratorType::parallel);
+
+    SmallVector<AffineMap> indexingMaps = {
+        inversePermutation(AffineMap::get(nloops, 0, exprs, context)),
+        AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+    auto transposedOp = rewriter.create<linalg::GenericOp>(
+        loc, outputTensor.getType(),
+        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+        loopAttributeTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+
+    return transposedOp.getResult(0);
+  };
+
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  // Transpose input, filter and output so channels are outermost (right after
+  // the batch dimension for the input and the output).
+  Value inputT = transposeOperand(input, {0, 3, 1, 2});
+  Value filterT = transposeOperand(filter, {2, 0, 1});
+
+  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
+  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
+
+  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
+  bindDims(context, nDim, cDim, ohDim, owDim, khDim, kwDim);
+
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  // col[n, c, oh, ow, kh, kw] = inputT[n, c, oh * sh + kh, ow * sw + kw].
+  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim};
+
+  auto nloops = colTensorShape.size();
+
+  SmallVector<utils::IteratorType> loopAttributeTypes(
+      nloops, utils::IteratorType::parallel);
+
+  SmallVector<AffineMap> indexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
+      loopAttributeTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputTensorType;
+  if (n == 1) {
+    // Batch and channels fold into a single matvec batch of size n * c, which
+    // matches the filter's channel dimension because n == 1.
+    img2ColTensorReassocIndices = {{0, 1}, {2, 3}, {4, 5}};
+    outputReassocIndices = {{0, 1}, {2, 3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n * c, oh * ow, fh * fw}, inputType.getElementType());
+    reshapedOutputTensorType =
+        RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
+  } else {
+    // The filter is only indexed by the channel, so batch and channels must
+    // stay separate dimensions.
+    img2ColTensorReassocIndices = {{0}, {1}, {2, 3}, {4, 5}};
+    outputReassocIndices = {{0}, {1}, {2, 3}};
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, c, oh * ow, fh * fw}, inputType.getElementType());
+    reshapedOutputTensorType =
+        RankedTensorType::get({n, c, oh * ow}, outputType.getElementType());
+  }
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2}};
+  auto reshapedFilterTensorType =
+      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterTensorType, filterT, filterReassocIndices);
+  Value reshapedOutputTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputTensorType, transposedOutputTensor,
+      outputReassocIndices);
+
+  Value result;
+  if (n == 1) {
+    auto batchMatVecOp = rewriter.create<linalg::BatchMatvecOp>(
+        loc, TypeRange{reshapedOutputTensor.getType()},
+        ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+        ValueRange{reshapedOutputTensor});
+    result = batchMatVecOp.getResult(0);
+  } else {
+    // out[b, ch, m] += img2col[b, ch, m, k] * filter[ch, k]
+    AffineExpr bDim, chDim, mDim, kDim;
+    bindDims(context, bDim, chDim, mDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {bDim, chDim, mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {chDim, kDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, chDim, mDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputTensorType,
+        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+        /*outputs=*/ValueRange{reshapedOutputTensor},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(nestedLoc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+
+  // Back to the channel-first output layout (n, c, oh, ow), then to NHWC.
+  Value resultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+      loc, transposedOutputTensor.getType(), result, outputReassocIndices);
+
+  Value transposedResult = transposeOperand(resultReshaped, {0, 2, 3, 1});
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        transposedResult.getDefiningOp());
+}
+
+/// Rewrites linalg.conv_2d_nchw_fchw as an img2col packing linalg.generic
+/// followed by filter x img2col: a linalg.matmul (batch == 1) or a batched
+/// matmul expressed as a linalg.generic (batch > 1). All match failures are
+/// reported before any IR is created.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
+  // The rewrite creates tensor.empty and tensor reshape ops, so it is only
+  // valid on tensors.
+  if (!convOp.hasTensorSemantics())
+    return rewriter.notifyMatchFailure(convOp, "expected tensor semantics");
+
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // The output sizes are baked into static reshape and result types below.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  // Keep sizes in 64 bits: products such as oh * ow or fh * fw * ic become
+  // tensor dimensions and could overflow a 32-bit int.
+  int64_t n = outputShape[0];
+  int64_t oc = outputShape[1];
+  int64_t oh = outputShape[2];
+  int64_t ow = outputShape[3];
+  int64_t ic = filterShape[1];
+  int64_t fh = filterShape[2];
+  int64_t fw = filterShape[3];
+
+  // The batched (n != 1) path emits a hand-written multiply-accumulate body
+  // that inserts no casts, so all operands must share one element type there.
+  Type outputElementType = outputType.getElementType();
+  if (n != 1 && (inputType.getElementType() != outputElementType ||
+                 filterType.getElementType() != outputElementType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected matching element types for batched convolution");
+
+  Location loc = convOp.getLoc();
+
+  // img2col tensor: for every batch, the (ic, fh, fw) window of input values
+  // feeding each output pixel (oh, ow).
+  SmallVector<int64_t> colTensorShape = {n, ic, fh, fw, oh, ow};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  MLIRContext *context = rewriter.getContext();
+
+  AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
+  bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  // col[n, ic, kh, kw, oh, ow] = input[n, ic, oh * sh + kh, ow * sw + kw].
+  SmallVector<AffineExpr> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim};
+
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  // The collapsed filter keeps the filter's own element type.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  // Collapse to matrices: img2col [(n x) ic*fh*fw, oh*ow] and output
+  // [(n x) oc, oh*ow]; the batch dimension is only kept when n != 1.
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0, 1}, {2, 3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0}, {1}, {2, 3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+  }
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. M x K * (B x) K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(nestedLoc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+namespace {
+
+/// Pattern wrapper that rewrites linalg.conv_2d_nhwc_hwcf via rewriteInIm2Col.
+class ConvertConv2DNhwcHwcf final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNhwcHwcfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    // The returned op pair is only needed by the transform op; patterns just
+    // report success or failure.
+    return success(succeeded(rewriteInIm2Col(rewriter, convOp)));
+  }
+};
+
+/// Pattern wrapper that rewrites linalg.depthwise_conv_2d_nhwc_hwc via
+/// rewriteInIm2Col.
+class ConvertDepthwiseConv2DNhwcHwc final
+    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
+public:
+  using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return success(succeeded(rewriteInIm2Col(rewriter, convOp)));
+  }
+};
+
+/// Pattern wrapper that rewrites linalg.conv_2d_nchw_fchw via rewriteInIm2Col.
+class ConvertConv2DNchwFchw final
+    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNchwFchwOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return success(succeeded(rewriteInIm2Col(rewriter, convOp)));
+  }
+};
+} // namespace
+
+/// Registers the img2col rewrite patterns for the three supported
+/// convolution kinds into `patterns`.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
+  patterns.add<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+               ConvertConv2DNchwFchw>(patterns.getContext());
+}
+} // end namespace linalg
+} // end namespace mlir

diff  --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
new file mode 100644
index 0000000000000..e33e51ddababb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// Check that the im2col patterns are properly connected with the
+// transform dialect.
+
+// Dynamically shaped convolutions are not supported: check that the
+// transform reports a failure and points at the offending op.
+// TODO: Hook up the rewriter errors in transform dialect.
+func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+    // expected-note@below {{when applied to this op}}
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+    return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-error@below {{failed to apply}}
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// Check that the transform returns the expected handles: the img2col tensor
+// producer and the op producing the final (reshaped) result.
+
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//      CHECK: @conv_16433136
+//      CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//      CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+//      CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP0]]
+//           CHECK-SAME: #[[MAP1]]
+//                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %[[IN_DATA]] : f32
+//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+//           CHECK-SAME: [0, 1, 2], [3, 4, 5]
+//           CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+//           CHECK-SAME: [0, 1, 2], [3]
+//           CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
+//           CHECK-SAME: [0, 1, 2], [3]
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: return %[[RESULT]]
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+// Rewrite the conv and print both returned handles (see the "IR printer:" CHECKs above).
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation
+  transform.print %transformed {name = "transformed"}: !pdl.operation
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+// CHECK: @depthwise_conv_hwc_114x16x3
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
+//      CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
+//      CHECK: %[[INPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
+// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
+// CHECK-NEXT:     linalg.yield %[[ARG3]] : f32
+// CHECK-NEXT:  } -> tensor<1x16x114x114xf32>
+//      CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
+//      CHECK: %[[FILTER_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+//      CHECK:      linalg.yield
+//      CHECK:    } -> tensor<16x3x3xf32>
+//      CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
+//      CHECK: %[[OUTPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
+// CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:     linalg.yield
+// CHECK-NEXT:  } -> tensor<1x16x112x112xf32>
+//      CHECK:  %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:         linalg.yield
+// CHECK-NEXT:    } -> tensor<1x16x112x112x3x3xf32>
+//      CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME:    tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
+//      CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
+// CHECK-SAME:    tensor<16x3x3xf32> into tensor<16x9xf32>
+//      CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
+// CHECK-SAME:    tensor<1x16x112x112xf32> into tensor<16x12544xf32>
+//      CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
+//      CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
+// CHECK-SAME:    tensor<16x12544xf32> into tensor<1x16x112x112xf32>
+//      CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
+//      CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
+// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT:      linalg.yield
+// CHECK-NEXT:    } -> tensor<1x112x112x16xf32>
+//      CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
+func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
+    %0 = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+    return %0 : tensor<1x112x112x16xf32>
+}
+// Depthwise variant: the CHECKs above expect transposes around a linalg.batch_matvec.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func.func @batch_nhwc_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
+//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
+//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+//      CHECK:   %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:   ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
+//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//      CHECK:     linalg.yield %[[ADD]] : f32
+//      CHECK:   } -> tensor<8x196x16xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+//      CHECK:   return %[[CS_FINAL]]
+func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
+    return %0 : tensor<8x14x14x16xf32>
+}
+// Batch size 8: the CHECKs above expect a linalg.generic contraction, not linalg.matmul.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func.func @batch_nchw_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
+//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
+//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+//      CHECK:   %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:   ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
+//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//      CHECK:     linalg.yield %[[ADD]] : f32
+//      CHECK:   } -> tensor<8x16x196xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+//      CHECK:   return %[[CS_FINAL]]
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+    %0 = linalg.conv_2d_nchw_fchw
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+      outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+    return %0 : tensor<8x16x14x14xf32>
+}
+// NCHW/FCHW layout: the CHECKs above expect the collapsed filter as the LHS of the contraction.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}


        


More information about the Mlir-commits mailing list