Adds the Img2Col transformation for the fhwc channel ordering in a
Conv2D. Because the channel ordering changes the matrix dimensions of
the flattened filter, the actual "matrix multiplication" is implemented
slightly differently: instead of a regular row-column dot product, this
arrangement uses a row-row dot product; otherwise the filter matrix
would first need to be transposed.

Adds a lit test, driven by the transform dialect, to check that the
semantics of the optimization are correct.

 .../Dialect/Linalg/Transforms/Transforms.h    |   8 +
 .../TransformOps/LinalgTransformOps.cpp       |   3 +
 .../Transforms/ConvertConv2DToImg2Col.cpp     | 147 ++++++++++++++++++
 .../Linalg/convert-conv2d-to-img2col.mlir     |  70 +++++++++
 4 files changed, 228 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..3597209d7f90c25 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+/// Same as the above but for the Fhwc filter layout. The collapsed filter is
+/// used untransposed, so the contraction is a row-wise dot product rather than
+/// a row-column one; a regular matrix multiplication would first require
+/// transposing the filter to produce the correct output dimensions. Fails to
+/// match on non-static shapes or non-unit dilations.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
 /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
 /// reduction among the input channels so each convolution can be a
 /// matrix-vector product and by transposing both input filter so channels are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249cfb..8508507871d0c6c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
             return rewriteInIm2Col(rewriter, op);
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
           .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
             return rewriteInIm2Col(rewriter, op);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 275e78aaa73dde6..848a7b5e7fc52c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
+/// Rewrites a linalg.conv_2d_nhwc_fhwc op as an im2col gather followed by a
+/// contraction against the collapsed filter.
+///
+/// The input is gathered into a (n, oh*ow, fh*fw*ic) column tensor and the
+/// filter is collapsed to (oc, fh*fw*ic). Since the filter is not transposed,
+/// the contraction indexes both operands along their last dimension, i.e. it
+/// is a row-wise dot product rather than a regular matmul.
+///
+/// Fails to match unless the input, filter and output shapes are static and
+/// all dilations are one. On success, returns the im2col producer op and the
+/// final tensor.expand_shape op whose result replaces `convOp`.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+  // The output dimensions are multiplied together below to build the collapsed
+  // types, so a dynamic size would leak its sentinel value into that math.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+  // Strides do not depend on the iteration point, so read them once up front.
+  const int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
+  const int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];
+  Location loc = convOp.getLoc();
+  // Reshape filter and output to the RHS and result of the contraction:
+  // filter (oc, fh, fw, ic) -> (oc, fh*fw*ic), output -> (n, oh*ow, oc).
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+  // Convert the input to a (BMK) column tensor.
+  auto nloops = colTensorShape.size();
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 2);
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex = getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex,
+                                         fhIndex, strideH);
+        Value wIndex = getConvolvedIndex(nestedBuilder, nestedLoc, owIndex,
+                                         fwIndex, strideW);
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            nestedLoc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+  // Because we didn't transpose the filters we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "row-wise"
+  // dot products: both operands are indexed by k along their last dimension.
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // args = (im2col element, filter element, accumulator).
+        Value mul = createMul(nestedLoc, args[0], args[1], args[2].getType(),
+                              nestedBuilder);
+        Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
+  // Expand the (n, oh*ow, oc) result back to the original output shape.
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
 namespace {
 class ConvertConv2DNhwcHwcf final
@@ -534,6 +668,19 @@ class ConvertConv2DNchwFchw final
     return success();
+/// Pattern wrapper that forwards linalg.conv_2d_nhwc_fhwc ops to the matching
+/// rewriteInIm2Col overload.
+/// NOTE(review): no hunk in this patch adds this pattern to
+/// populateConvertConv2DToImg2ColPatterns - confirm it is registered there.
+class ConvertConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    // Only the success state is needed here; the returned op pair is dropped.
+    return failed(rewriteInIm2Col(rewriter, convOp)) ? failure() : success();
+  }
 } // end anonymous namespace
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 657cf83f25460fd..b2470ed7b748042 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -279,6 +279,76 @@ transform.sequence failures(propagate) {
 // -----
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
+// Extract from the input tensor.
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
+//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP0]]
+//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %{{.+}} : f32
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP1]]
+//           CHECK-SAME: #[[MAP2]]
+//           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
+//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+//                CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//                CHECK:     linalg.yield %[[ADD]] : f32
+//                CHECK: } -> tensor<1x196x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: return %[[RESULT]]
+// Unit-stride, unit-dilation convolution of a 1x16x16x4 input with a
+// 16x3x3x4 (fhwc) filter, accumulating into a 1x14x14x16 output.
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+// Match the convolution, apply the img2col conversion, and print both result
+// handles under the names the checks above look for.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+// -----
 // Check for signed extend when the input type is smaller than the accumulator type.
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

