[Mlir-commits] [mlir] 17b2e73 - [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col
Thomas Raoux
llvmlistbot at llvm.org
Thu Feb 23 14:29:31 PST 2023
Author: Quentin Colombet
Date: 2023-02-23T22:27:16Z
New Revision: 17b2e73cb477d42771fbc68a215dde648f3eaaef
URL: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef
DIFF: https://github.com/llvm/llvm-project/commit/17b2e73cb477d42771fbc68a215dde648f3eaaef.diff
LOG: [mlir][LinAlg][Transform] Add a transform op for conv2d to im2col
This patch adds patterns to convert `linalg.conv_2d_xxx` operations
into `linalg.generic` (for img2col packing) and `linalg.matmul`.
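To make the packing concrete, here is a minimal, hypothetical C++ sketch of what the `linalg.conv_2d_nhwc_hwcf` lowering computes. It uses plain loops, not MLIR, and assumes no padding and unit dilation, which matches what the patch requires. `ConvShape`, `directConv`, and `im2colConv` are illustrative names and are not part of the patch or of MLIR. The packing loop mirrors the img2col `linalg.generic` indexing `(n, oh * sh + kh, ow * sw + kw, c)`. The second loop nest is one matmul per batch element, with the filter shared across the batch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical plain-C++ model of the nhwc_hwcf im2col lowering; these names
// are illustrative, not MLIR APIs. Assumes no padding and unit dilation.
struct ConvShape {
  int n, h, w, c;  // input  (N, H, W, C)
  int kh, kw, f;   // filter (Kh, Kw, C, F)
  int sh, sw;      // strides
  int oh() const { return (h - kh) / sh + 1; }
  int ow() const { return (w - kw) / sw + 1; }
};

// Reference direct convolution; result is (N, Ho, Wo, F), row-major.
std::vector<float> directConv(const ConvShape &s, const std::vector<float> &in,
                              const std::vector<float> &flt) {
  std::vector<float> out(static_cast<std::size_t>(s.n) * s.oh() * s.ow() * s.f);
  for (int b = 0; b < s.n; ++b)
    for (int i = 0; i < s.oh(); ++i)
      for (int j = 0; j < s.ow(); ++j)
        for (int o = 0; o < s.f; ++o) {
          float acc = 0.0f;
          for (int y = 0; y < s.kh; ++y)
            for (int x = 0; x < s.kw; ++x)
              for (int ch = 0; ch < s.c; ++ch)
                acc += in[((b * s.h + i * s.sh + y) * s.w + j * s.sw + x) * s.c + ch] *
                       flt[((y * s.kw + x) * s.c + ch) * s.f + o];
          out[((b * s.oh() + i) * s.ow() + j) * s.f + o] = acc;
        }
  return out;
}

// im2col lowering: pack an (N, Ho*Wo, Kh*Kw*C) matrix, then compute one
// (Ho*Wo, Kh*Kw*C) x (Kh*Kw*C, F) matmul per batch element. The HWCF filter
// is already a row-major (Kh*Kw*C, F) matrix, so it is used without copying.
std::vector<float> im2colConv(const ConvShape &s, const std::vector<float> &in,
                              const std::vector<float> &flt) {
  assert(in.size() == static_cast<std::size_t>(s.n) * s.h * s.w * s.c);
  assert(flt.size() == static_cast<std::size_t>(s.kh) * s.kw * s.c * s.f);
  const int rows = s.oh() * s.ow();  // M: output spatial size
  const int k = s.kh * s.kw * s.c;   // K: filter spatial size x channels
  std::vector<float> col(static_cast<std::size_t>(s.n) * rows * k);
  for (int b = 0; b < s.n; ++b)
    for (int i = 0; i < s.oh(); ++i)
      for (int j = 0; j < s.ow(); ++j)
        for (int y = 0; y < s.kh; ++y)
          for (int x = 0; x < s.kw; ++x)
            for (int ch = 0; ch < s.c; ++ch)
              // Same access as the img2col linalg.generic input map:
              // (n, oh * sh + kh, ow * sw + kw, c).
              col[(b * rows + i * s.ow() + j) * k + (y * s.kw + x) * s.c + ch] =
                  in[((b * s.h + i * s.sh + y) * s.w + j * s.sw + x) * s.c + ch];
  // Batched matmul with a shared RHS (the filter); out is (N, Ho*Wo, F),
  // which has the same row-major layout as (N, Ho, Wo, F).
  std::vector<float> out(static_cast<std::size_t>(s.n) * rows * s.f);
  for (int b = 0; b < s.n; ++b)
    for (int m = 0; m < rows; ++m)
      for (int o = 0; o < s.f; ++o) {
        float acc = 0.0f;
        for (int kk = 0; kk < k; ++kk)
          acc += col[(b * rows + m) * k + kk] * flt[kk * s.f + o];
        out[(b * rows + m) * s.f + o] = acc;
      }
  return out;
}
```

On small integer-valued inputs, `directConv(s, in, flt)` and `im2colConv(s, in, flt)` should produce identical vectors. Both functions accumulate the same products in the same (kh, kw, c) order. The sketch copies the input windows explicitly into `col`, which is the memory cost that im2col trades for a plain matmul.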
The meat of the patch comes straight from IREE
(https://github.com/iree-org/iree).
(To the original authors: are you okay with that?)
What this patch adds is proper plumbing of the im2col patterns into the
transform dialect.
PS: Feel free to add more reviewers. I wanted to cover the original contributors of im2col in IREE, but I'm not sure I got all of them.
Reviewed By: nicolasvasilache, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D144108
Added:
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 11f3b3c634fdf..c53497839ea9a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1740,6 +1740,82 @@ def HoistRedundantVectorTransfersOp :
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::func::FuncOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp
+//===----------------------------------------------------------------------===//
+
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+ "structured.convert_conv2d_to_img2col",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
+ let description = [{
+ Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
+ and linalg.matmul.
+
+ A convolution operation can be written as a matrix-matrix multiplication by
+ unfolding the cross-correlation between input and filter and explicitly
+ copying overlapped sliding window inputs.
+
+ Consider an n x n 2D input X, one input/output channel, and a 2x2 filter W:
+ ```
+ [x(0, 0)  , x(0, 1)  , ..., x(0, n-1)  ]
+ [x(1, 0)  , x(1, 1)  , ..., x(1, n-1)  ]
+ [   .     ,    .     , ...,     .      ]        [w(0, 0), w(0, 1)]
+ [   .     ,    .     , ...,     .      ] (conv) [w(1, 0), w(1, 1)]
+ [   .     ,    .     , ...,     .      ]
+ [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+ ```
+
+ The packed input data (img2col) is a matrix with |rows| = output spatial
+ size and |columns| = filter spatial size. To compute the output Y(i, j), we
+ take the dot product between the input window starting at X(i, j) and the
+ filter. With unit stride this gives the following, where the l.h.s. is the
+ img2col matrix, one row per output position, and r.h.s. the flattened filter:
+ ```
+ [x(0,0), x(0,1), x(1,0), x(1,1)]
+ [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]^T
+ [x(0,2), x(0,3), x(1,2), x(1,3)]
+ [  .   ,   .   ,   .   ,   .   ]
+ ```
+
+ In general, for the 2D case with an (N, H, W, C) input, a (Kh, Kw, C, D)
+ filter, and an (N, Ho, Wo, D) output, the convolution is the matrix-matrix
+ multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
+ inputs. When N > 1, it is a batched matrix-matrix multiplication in which
+ the filter operand is shared across the batch.
+
+ Returns two handles:
+ - One on the operation that produces the img2col tensor.
+ - One on the final operation of the sequence that replaces the original
+ convolution.
+
+ #### Return modes:
+
+ Returns a definite failure if target is not isolated from above.
+ Returns a silenceable failure if the pattern application failed.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$img2col_tensor,
+ TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c782cab1e9f94..8e77b54251aa7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -879,6 +879,64 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populates patterns to transform linalg.conv_2d_xxx operations into
+/// linalg.generic (for img2col packing) and linalg.matmul.
+/// \see rewriteInIm2Col for more details.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+
+/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+/// and linalg.matmul.
+///
+/// A convolution operation can be written as a matrix-matrix multiplication by
+/// unfolding the cross-correlation between input and filter and explicitly
+/// copying overlapped sliding window inputs.
+///
+/// Consider an n x n 2D input X, one input/output channel, and a 2x2 filter W:
+/// [x(0, 0)  , x(0, 1)  , ..., x(0, n-1)  ]
+/// [x(1, 0)  , x(1, 1)  , ..., x(1, n-1)  ]
+/// [   .     ,    .     , ...,     .      ]        [w(0, 0), w(0, 1)]
+/// [   .     ,    .     , ...,     .      ] (conv) [w(1, 0), w(1, 1)]
+/// [   .     ,    .     , ...,     .      ]
+/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+///
+/// The packed input data (img2col) is a matrix with |rows| = output spatial
+/// size and |columns| = filter spatial size. To compute the output Y(i, j), we
+/// take the dot product between the input window starting at X(i, j) and the
+/// filter. With unit stride this gives the following, where the l.h.s. is the
+/// img2col matrix, one row per output position, and r.h.s. the flattened filter:
+///
+/// [x(0,0), x(0,1), x(1,0), x(1,1)]
+/// [x(0,1), x(0,2), x(1,1), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]^T
+/// [x(0,2), x(0,3), x(1,2), x(1,3)]
+/// [  .   ,   .   ,   .   ,   .   ]
+///
+/// In general, for the 2D case with an (N, H, W, C) input, a (Kh, Kw, C, D)
+/// filter, and an (N, Ho, Wo, D) output, the convolution is the matrix-matrix
+/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
+/// inputs. When N > 1, it is a batched matrix-matrix multiplication in which
+/// the filter operand is shared across the batch.
+///
+/// On success, return both the operation that produces the img2col tensor and
+/// the final operation of the sequence that replaces the original convolution.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that
+/// there is no reduction across input channels, so each convolution is a
+/// matrix-vector product. Transposing both input and filter so channels are
+/// outermost turns the computation into a batched matrix-vector product.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+ linalg::DepthwiseConv2DNhwcHwcOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that,
+/// because the channels are to the left of the image shape dimensions, the
+/// position of the contraction dimension in the resulting matmul is reversed.
+/// This swaps the LHS and RHS of the matmul compared with nhwc, i.e.
+/// (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo).
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
+
//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index aae073c28db8d..2d102383ddbe0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
@@ -3072,6 +3073,40 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+ linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ auto maybeTransformed =
+ TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
+ target)
+ .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+ return rewriteInIm2Col(rewriter, op);
+ })
+ .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
+ return rewriteInIm2Col(rewriter, op);
+ })
+ .Case([&](linalg::Conv2DNchwFchwOp op) {
+ return rewriteInIm2Col(rewriter, op);
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(op, "not supported");
+ });
+ if (failed(maybeTransformed))
+ return emitDefaultSilenceableFailure(target);
+ // Handle to the operation producing the img2col tensor.
+ results.push_back(maybeTransformed->first);
+ // Handle to the operation that replaces the original convolution.
+ results.push_back(maybeTransformed->second);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7ede7e4f96915..adcc87f42dab2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
ConstantFold.cpp
ConvertToDestinationStyle.cpp
+ ConvertConv2DToImg2Col.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
Detensorize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
new file mode 100644
index 0000000000000..14bff411ef8c1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -0,0 +1,540 @@
+//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <utility>
+
+namespace mlir {
+namespace linalg {
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+ return llvm::all_of(
+ attr, [](APInt element) { return element.getSExtValue() == 1; });
+}
+
+static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
+ bool isInt = x.getType().isa<IntegerType>();
+ if (isInt)
+ return builder.create<arith::AddIOp>(loc, x, y);
+ return builder.create<arith::AddFOp>(loc, x, y);
+}
+
+static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
+ bool isInt = x.getType().isa<IntegerType>();
+ if (isInt)
+ return builder.create<arith::MulIOp>(loc, x, y);
+ return builder.create<arith::MulFOp>(loc, x, y);
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
+ auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+ auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+ auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ // TODO: Support dilation.
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ MLIRContext *context = rewriter.getContext();
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ int n = outputShape[0];
+ int oh = outputShape[1];
+ int ow = outputShape[2];
+ int oc = outputShape[3];
+ int fh = filterShape[0];
+ int fw = filterShape[1];
+ int ic = filterShape[2];
+
+ Location loc = convOp.getLoc();
+
+ SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+
+ Value colTensor = rewriter.create<tensor::EmptyOp>(
+ loc, colTensorShape, inputType.getElementType());
+
+ AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
+ bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
+
+ AffineExpr shSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[0]);
+ AffineExpr swSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
+ owDim * swSym + kwDim, icDim};
+
+ auto nloops = colTensorShape.size();
+
+ auto parallel = utils::IteratorType::parallel;
+ auto reduction = utils::IteratorType::reduction;
+ SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ AffineMap::get(nloops, 0, inputExprs, context),
+ AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+ auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+ loc, colTensor.getType(),
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+ img2colIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ });
+
+ SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+ SmallVector<ReassociationIndices> outputReassocIndices;
+ RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+ if (n == 1) {
+ img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
+ outputReassocIndices = {{0, 1, 2}, {3}};
+
+ reshapedImg2ColTensorType = RankedTensorType::get(
+ {oh * ow, fh * fw * ic}, inputType.getElementType());
+ reshapedOutputType =
+ RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
+ } else {
+ img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
+ outputReassocIndices = {{0}, {1, 2}, {3}};
+
+ reshapedImg2ColTensorType = RankedTensorType::get(
+ {n, oh * ow, fh * fw * ic}, inputType.getElementType());
+ reshapedOutputType =
+ RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+ }
+
+ SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+ auto reshapedFilterType =
+ RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+
+ Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+ img2ColTensorReassocIndices);
+
+ Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterType, filter, filterReassocIndices);
+
+ Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputType, output, outputReassocIndices);
+
+ Value result;
+ if (n == 1) {
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, reshapedOutputType,
+ ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
+ ArrayRef<Value>{reshapedOutput});
+ result = matmulOp.getResults().front();
+ } else {
+ // For cases where batch is not 1, we need to keep the batch dimension
+ // separate. Because the filter does not share the same batch dimension,
+ // the batch dimension is only used in indexing the input and output. Thus
+ // we cannot use existing linalg named ops like linalg.batch_matmul.
+ // i.e. (B x) M x K * K x N = (B x) M x N
+ AffineExpr bDim, mDim, nDim, kDim;
+ bindDims(context, bDim, mDim, nDim, kDim);
+ auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+ auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+ auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+ SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+ parallel, reduction};
+
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, reshapedOutputType,
+ /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
+ /*outputs=*/ValueRange{reshapedOutput},
+ ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value add = createAdd(loc, mul, args[2], nestedBuilder);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ });
+ result = genericOp.getResults().front();
+ }
+
+ auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+ loc, outputType, result, outputReassocIndices);
+
+ rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+ return std::make_pair(img2ColTensor.getOperation(),
+ reshapedResult.getOperation());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+ linalg::DepthwiseConv2DNhwcHwcOp convOp) {
+ auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
+ auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
+ auto outputType = convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ // TODO: Support dilation.
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ Location loc = convOp.getLoc();
+
+ auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
+ auto operandTensorType = operand.getType().cast<RankedTensorType>();
+ auto nloops = indices.size();
+ ArrayRef<int64_t> inputShape = operandTensorType.getShape();
+
+ SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
+ llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+ return rewriter.getAffineDimExpr(index);
+ }));
+
+ SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
+ indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
+
+ Value outputTensor = rewriter.create<tensor::EmptyOp>(
+ loc, targetShape, operandTensorType.getElementType());
+
+ SmallVector<utils::IteratorType> loopAttributeTypes(
+ nloops, utils::IteratorType::parallel);
+
+ SmallVector<AffineMap> indexingMaps = {
+ inversePermutation(
+ AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
+ AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+ auto transposedOp = rewriter.create<linalg::GenericOp>(
+ loc, outputTensor.getType(),
+ /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+ loopAttributeTypes,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ });
+
+ return transposedOp.getResult(0);
+ };
+
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ // Transpose input, filter so channels are outermost
+ Value inputT = transposeOperand(input, {0, 3, 1, 2});
+ Value filterT = transposeOperand(filter, {2, 0, 1});
+ ArrayRef<int64_t> filterTShape =
+ filterT.getType().cast<RankedTensorType>().getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ int n = outputShape[0];
+ int oh = outputShape[1];
+ int ow = outputShape[2];
+ int c = outputShape[3];
+ int fh = filterTShape[1];
+ int fw = filterTShape[2];
+
+ SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
+ Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
+
+ AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
+ bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
+
+ AffineExpr shSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[0]);
+ AffineExpr swSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
+ owDim * swSym + kwDim};
+
+ auto nloops = colTensorShape.size();
+
+ SmallVector<utils::IteratorType> loopAttributeTypes(
+ nloops, utils::IteratorType::parallel);
+
+ SmallVector<AffineMap> indexingMaps = {
+ AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+ AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+ Value colTensor = rewriter.create<tensor::EmptyOp>(
+ loc, colTensorShape, inputType.getElementType());
+
+ auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+ loc, colTensor.getType(),
+ /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
+ loopAttributeTypes,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ });
+
+ SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
+ {0, 1}, {2, 3}, {4, 5}};
+ SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
+ SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
+ {2, 3}};
+
+ auto reshapedImg2ColTensorType = RankedTensorType::get(
+ {n * c, oh * ow, fh * fw}, inputType.getElementType());
+ auto reshapedFilterTensorType =
+ RankedTensorType::get({c, fh * fw}, filterType.getElementType());
+ auto reshapedOutputTensorType =
+ RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
+
+ Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+ img2ColTensorReassocIndices);
+ Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
+ Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputTensorType, transposedOutputTensor,
+ outputReassociationIndice);
+
+ auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
+ loc, TypeRange{reshapedoutputTensor.getType()},
+ ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+ ValueRange{reshapedoutputTensor});
+
+ SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
+ {2, 3}};
+
+ Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+ loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
+ batchMatVecReassociationIndice);
+
+ Value transposedResult =
+ transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
+
+ rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
+ return std::make_pair(img2ColTensor.getOperation(),
+ transposedResult.getDefiningOp());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
+ auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+ auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+ auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ // TODO: Support dilation.
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ auto filterShape = filterType.getShape();
+ auto outputShape = outputType.getShape();
+
+ int n = outputShape[0];
+ int oc = outputShape[1];
+ int oh = outputShape[2];
+ int ow = outputShape[3];
+ int ic = filterShape[1];
+ int fh = filterShape[2];
+ int fw = filterShape[3];
+
+ auto loc = convOp.getLoc();
+
+ SmallVector<int64_t, 4> colTensorShape = {n, ic, fh, fw, oh, ow};
+
+ Value colTensor = rewriter.create<tensor::EmptyOp>(
+ loc, colTensorShape, inputType.getElementType());
+
+ MLIRContext *context = rewriter.getContext();
+
+ AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
+ bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+
+ auto shSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[0]);
+ auto swSym = rewriter.getAffineConstantExpr(
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ SmallVector<AffineExpr, 4> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
+ owDim * swSym + kwDim};
+
+ auto nloops = colTensorShape.size();
+
+ auto parallel = utils::IteratorType::parallel;
+ auto reduction = utils::IteratorType::reduction;
+ SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
+
+ SmallVector<AffineMap, 4> img2colIndexingMaps = {
+ AffineMap::get(nloops, 0, inputExprs, context),
+ AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+ auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+ loc, colTensor.getType(),
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+ img2colIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ });
+
+ SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+ auto reshapedFilterType =
+ RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
+ Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterType, filter, filterReassocIndices);
+
+ SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+ SmallVector<ReassociationIndices> outputReassocIndices;
+ RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+ if (n == 1) {
+ img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
+ outputReassocIndices = {{0, 1}, {2, 3}};
+
+ reshapedImg2ColTensorType = RankedTensorType::get(
+ {fh * fw * ic, oh * ow}, inputType.getElementType());
+ reshapedOutputType =
+ RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
+ } else {
+ img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
+ outputReassocIndices = {{0}, {1}, {2, 3}};
+
+ reshapedImg2ColTensorType = RankedTensorType::get(
+ {n, fh * fw * ic, oh * ow}, inputType.getElementType());
+ reshapedOutputType =
+ RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+ }
+
+ Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+ img2ColTensorReassocIndices);
+
+ Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputType, output, outputReassocIndices);
+
+ Value result;
+ if (n == 1) {
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, reshapedOutputType,
+ ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
+ ArrayRef<Value>{reshapedOutput});
+ result = matmulOp.getResults().front();
+ } else {
+ // For cases where batch is not 1, we need to keep the batch dimension
+ // separate. Because the filter does not share the same batch dimension,
+ // the batch dimension is only used in indexing the input and output. Thus
+ // we cannot use existing linalg named ops like linalg.batch_matmul.
+ // i.e. M x K * (B x) K x N = (B x) M x N
+ AffineExpr bDim, mDim, nDim, kDim;
+ bindDims(context, bDim, mDim, nDim, kDim);
+ auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+ auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+ auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+ SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+ parallel, reduction};
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, reshapedOutputType,
+ /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
+ /*outputs=*/ValueRange{reshapedOutput},
+ ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value add = createAdd(loc, mul, args[2], nestedBuilder);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ });
+ result = genericOp.getResults().front();
+ }
+
+ auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+ loc, outputType, result, outputReassocIndices);
+
+ rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+ return std::make_pair(img2ColTensor.getOperation(),
+ reshapedResult.getOperation());
+}
+
+namespace {
+
+class ConvertConv2DNhwcHwcf final
+ : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(rewriteInIm2Col(rewriter, convOp)))
+ return failure();
+ return success();
+ }
+};
+
+class ConvertDepthwiseConv2DNhwcHwc final
+ : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
+public:
+ using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(rewriteInIm2Col(rewriter, convOp)))
+ return failure();
+ return success();
+ }
+};
+
+class ConvertConv2DNchwFchw final
+ : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(rewriteInIm2Col(rewriter, convOp)))
+ return failure();
+ return success();
+ }
+};
+} // end anonymous namespace
+
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+ ConvertConv2DNchwFchw>(context);
+}
+} // end namespace linalg
+} // end namespace mlir
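The following minimal C++ sketch is not the patch's code. It makes the shape arithmetic of the NHWC/HWCF rewrite concrete for the shapes used in the `@conv_16433136` test. The sketch packs the input with the im2col indexing map `(d0, d1 + d3, d2 + d4, d5)` and multiplies the resulting 196x36 matrix by the filter viewed as a 36x16 matrix. It then checks that the result matches a direct convolution. It assumes unit strides and dilations, no padding, static shapes, a zero-initialized output, and row-major buffers. The namespace `nhwc` and every function in it are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

namespace nhwc {  // hypothetical namespace, illustration only
// Shapes from the @conv_16433136 test: NHWC input 1x16x16x4, HWCF filter
// 3x3x4x16, NHWC output 1x14x14x16. Stride and dilation are 1, no padding.
constexpr int N = 1, H = 16, W = 16, C = 4;
constexpr int KH = 3, KW = 3, F = 16;
constexpr int OH = H - KH + 1, OW = W - KW + 1;  // 14 x 14

// Row-major flat offsets for each tensor.
inline int inIdx(int n, int h, int w, int c) { return ((n * H + h) * W + w) * C + c; }
inline int fltIdx(int kh, int kw, int c, int f) { return ((kh * KW + kw) * C + c) * F + f; }
inline int outIdx(int n, int oh, int ow, int f) { return ((n * OH + oh) * OW + ow) * F + f; }

// Reference: out[n,oh,ow,f] += in[n,oh+kh,ow+kw,c] * flt[kh,kw,c,f].
std::vector<float> directConv(const std::vector<float>& in, const std::vector<float>& flt) {
  std::vector<float> out(N * OH * OW * F, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int kh = 0; kh < KH; ++kh)
          for (int kw = 0; kw < KW; ++kw)
            for (int c = 0; c < C; ++c)
              for (int f = 0; f < F; ++f)
                out[outIdx(n, oh, ow, f)] +=
                    in[inIdx(n, oh + kh, ow + kw, c)] * flt[fltIdx(kh, kw, c, f)];
  return out;
}

// im2col: col[n,oh,ow,kh,kw,c] = in[n, oh+kh, ow+kw, c]. Stored row-major, this
// is already the (N*OH*OW) x (KH*KW*C) matrix of collapse_shape [[0,1,2],[3,4,5]].
std::vector<float> im2col(const std::vector<float>& in) {
  std::vector<float> col;
  col.reserve(N * OH * OW * KH * KW * C);
  for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int kh = 0; kh < KH; ++kh)
          for (int kw = 0; kw < KW; ++kw)
            for (int c = 0; c < C; ++c)
              col.push_back(in[inIdx(n, oh + kh, ow + kw, c)]);
  return col;
}

// a is M x K, b is K x P, the result is M x P (all row-major).
std::vector<float> matmul(const std::vector<float>& a, const std::vector<float>& b,
                          int M, int K, int P) {
  std::vector<float> r(static_cast<std::size_t>(M) * P, 0.0f);
  for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k)
      for (int j = 0; j < P; ++j)
        r[i * P + j] += a[i * K + k] * b[k * P + j];
  return r;
}

// Uses small integer-valued data (so float sums are exact) and checks that
// im2col + matmul reproduces the direct convolution element by element.
bool checkLowering() {
  std::vector<float> in(N * H * W * C), flt(KH * KW * C * F);
  for (std::size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i % 7) - 3.0f;
  for (std::size_t i = 0; i < flt.size(); ++i) flt[i] = static_cast<float>(i % 5) - 2.0f;
  const int M = N * OH * OW;  // 196
  const int K = KH * KW * C;  // 36
  // The HWCF filter is already a K x F (36x16) matrix in row-major order.
  std::vector<float> viaMatmul = matmul(im2col(in), flt, M, K, F);
  std::vector<float> expected = directConv(in, flt);
  // Reading the M x F (196x16) result as 1x14x14x16 is the expand_shape step.
  assert(viaMatmul.size() == expected.size());
  for (std::size_t i = 0; i < expected.size(); ++i)
    if (std::fabs(viaMatmul[i] - expected[i]) > 1e-4f) return false;
  return true;
}
}  // namespace nhwc
```

In this layout the HWCF filter and the row-major im2col buffer already match the order the `tensor.collapse_shape` ops expect. The reshapes are therefore pure reinterpretations, and only the im2col copy moves data.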
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
new file mode 100644
index 0000000000000..e33e51ddababb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// Check that the im2col patterns are properly connected with the
+// transform dialect.
+
+// Non-static shapes are not supported.
+// Check that we emit an error.
+// TODO: Hook up the rewriter errors in transform dialect.
+func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // expected-note @below {{when applied to this op}}
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
+ outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ // expected-error @below {{failed to apply}}
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// Check that we get the proper handles for the img2col tensor producer
+// and the final instruction.
+
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: @conv_16433136
+// CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: #[[MAP1]]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: [0, 1, 2], [3, 4, 5]
+// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation
+ transform.print %transformed {name = "transformed"}: !pdl.operation
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+// CHECK: @depthwise_conv_hwc_114x16x3
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
+// CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
+// CHECK: %[[INPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
+// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[ARG3]] : f32
+// CHECK-NEXT: } -> tensor<1x16x114x114xf32>
+// CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
+// CHECK: %[[FILTER_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x3x3xf32>
+// CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
+// CHECK: %[[OUTPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
+// CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
+// CHECK-SAME: tensor<16x3x3xf32> into tensor<16x9xf32>
+// CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
+// CHECK-SAME: tensor<1x16x112x112xf32> into tensor<16x12544xf32>
+// CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
+// CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
+// CHECK-SAME: tensor<16x12544xf32> into tensor<1x16x112x112xf32>
+// CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x112x112x16xf32>
+// CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
+func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+ return %0 : tensor<1x112x112x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nhwc_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x196x16xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
+ return %0 : tensor<8x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nchw_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x16x196xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+ %0 = linalg.conv_2d_nchw_fchw
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+ outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+ return %0 : tensor<8x16x14x14xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
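The depthwise case checked in `@depthwise_conv_hwc_114x16x3` follows a different route. It transposes to channel-major layout, builds a 16x12544x9 im2col tensor, and uses `linalg.batch_matvec` instead of a matmul. The minimal C++ sketch below mirrors that sequence on the test's shapes and compares it against a direct depthwise convolution. It is not the patch's code. It assumes N = 1, unit strides and dilations, no padding, and a zero-initialized output. In contrast, the CHECK lines transpose the existing `%output` and accumulate into it. The namespace `dw` and its functions are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

namespace dw {  // hypothetical namespace, illustration only
// Shapes from @depthwise_conv_hwc_114x16x3: NHWC input 1x114x114x16,
// HWC filter 3x3x16, NHWC output 1x112x112x16. Batch N = 1 is folded away.
constexpr int H = 114, W = 114, C = 16, KH = 3, KW = 3;
constexpr int OH = H - KH + 1, OW = W - KW + 1;  // 112 x 112
constexpr int P = OH * OW, K = KH * KW;          // 12544, 9

// Reference depthwise conv: out[oh,ow,c] += in[oh+kh,ow+kw,c] * flt[kh,kw,c].
std::vector<float> direct(const std::vector<float>& in, const std::vector<float>& flt) {
  std::vector<float> out(P * C, 0.0f);
  for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow)
      for (int c = 0; c < C; ++c)
        for (int kh = 0; kh < KH; ++kh)
          for (int kw = 0; kw < KW; ++kw)
            out[(oh * OW + ow) * C + c] +=
                in[((oh + kh) * W + (ow + kw)) * C + c] * flt[(kh * KW + kw) * C + c];
  return out;
}

// Transpose, im2col, batch_matvec, then transpose back, as in the CHECK lines.
std::vector<float> viaBatchMatvec(const std::vector<float>& in, const std::vector<float>& flt) {
  std::vector<float> inT(C * H * W);  // NHWC -> NCHW, i.e. C x H x W
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      for (int c = 0; c < C; ++c)
        inT[(c * H + h) * W + w] = in[(h * W + w) * C + c];
  std::vector<float> fltT(C * K);  // HWC -> CHW, i.e. the C x 9 matrix
  for (int kh = 0; kh < KH; ++kh)
    for (int kw = 0; kw < KW; ++kw)
      for (int c = 0; c < C; ++c)
        fltT[c * K + kh * KW + kw] = flt[(kh * KW + kw) * C + c];
  // col[c, oh, ow, kh, kw] = inT[c, oh+kh, ow+kw], viewed as C x P x K.
  std::vector<float> col(static_cast<std::size_t>(C) * P * K);
  for (int c = 0; c < C; ++c)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int kh = 0; kh < KH; ++kh)
          for (int kw = 0; kw < KW; ++kw)
            col[(c * P + oh * OW + ow) * K + kh * KW + kw] =
                inT[(c * H + oh + kh) * W + ow + kw];
  // batch_matvec: res[c, p] = sum_k col[c, p, k] * fltT[c, k].
  std::vector<float> res(C * P, 0.0f);
  for (int c = 0; c < C; ++c)
    for (int p = 0; p < P; ++p)
      for (int k = 0; k < K; ++k)
        res[c * P + p] += col[(c * P + p) * K + k] * fltT[c * K + k];
  std::vector<float> out(P * C);  // C x (OH*OW) -> NHWC
  for (int c = 0; c < C; ++c)
    for (int p = 0; p < P; ++p)
      out[p * C + c] = res[c * P + p];
  return out;
}

// Small integer-valued data keeps every float sum exact.
bool checkLowering() {
  std::vector<float> in(H * W * C), flt(K * C);
  for (std::size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i % 7) - 3.0f;
  for (std::size_t i = 0; i < flt.size(); ++i) flt[i] = static_cast<float>(i % 5) - 2.0f;
  std::vector<float> got = viaBatchMatvec(in, flt), want = direct(in, flt);
  assert(got.size() == want.size());
  for (std::size_t i = 0; i < want.size(); ++i)
    if (std::fabs(got[i] - want[i]) > 1e-4f) return false;
  return true;
}
}  // namespace dw
```

The depthwise form maps to a batched matrix-vector product. Each channel has its own 9-element filter row, so there is no shared reduction over input channels that a single matmul could exploit. The price is the explicit transposes on the input, filter, and result.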