[Mlir-commits] [mlir] [mlir][linalg] Introduce new `linalg.conv` op (PR #117688)

Felix Schneider llvmlistbot at llvm.org
Tue Nov 26 01:26:49 PST 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/117688

This patch lays the groundwork for the new `linalg.conv` op which is designed to replace the multitude of `linalg.conv_...` as well as `linalg.depthwise_conv_...` ops.

A test pass (`-test-linalg-new-conv`) is implemented which converts the old conv ops to the new op. The `-linalg-generalize-named-ops` pass can then be used to convert both the old and the new ops to `linalg.generic` ops for comparison.

>From 4ff4a8a71ef5d118497c269d8f9b60bedd2d4ec8 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Tue, 26 Nov 2024 10:25:47 +0100
Subject: [PATCH] [mlir][linalg] Introduce new `linalg.conv` op

This patch lays the groundwork for the new `linalg.conv` op which is
designed to replace the multitude of `linalg.conv_...` as well as
`linalg.depthwise_conv_...` ops.

A test pass (`-test-linalg-new-conv`) is implemented which converts the old
conv ops to the new op. The `-linalg-generalize-named-ops` pass can then be
used to convert both the old and the new ops to `linalg.generic` ops for comparison.
---
 .../mlir/Dialect/Linalg/IR/LinalgBase.td      |  39 ++
 .../mlir/Dialect/Linalg/IR/LinalgEnums.td     |  26 +
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  27 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 116 ++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 246 +++++++
 .../Dialect/Linalg/generalize-new-conv.mlir   | 656 ++++++++++++++++++
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  57 ++
 mlir/test/lib/Dialect/Linalg/CMakeLists.txt   |   3 +-
 mlir/test/lib/Dialect/Linalg/TestNewConv.cpp  | 187 +++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 10 files changed, 1358 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/generalize-new-conv.mlir
 create mode 100644 mlir/test/lib/Dialect/Linalg/TestNewConv.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..b659241b5ed5b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -81,4 +81,43 @@ def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
 def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+
+/// Array parameter storing the dimension list of a `ConvDimsAttr`. It is
+/// printed and parsed as a brace-enclosed, comma-separated list of
+/// `ConvDimEnum` values, e.g. `{N, C, "1", "0"}`.
+def ConvolutionDimArray : ArrayRefParameter<"ConvDimEnum"> {
+  // Print '{', then the enum values separated by commas, then '}'.
+  let printer = [{
+    $_printer << '{';
+    llvm::interleaveComma($_self, $_printer, [&](ConvDimEnum en) {
+        $_printer.printStrippedAttrOrType(en);
+    });
+    $_printer << '}';
+  }];
+
+  // Parse '{', a list of enum values via FieldParser, then '}'. Any failing
+  // step makes the whole parse return failure().
+  let parser = [{
+    [&]() -> FailureOr<SmallVector<ConvDimEnum>> {
+        using Result = SmallVector<ConvDimEnum>;
+        if ($_parser.parseLBrace())
+            return failure();
+        FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+        if (failed(result))
+            return failure();
+        if ($_parser.parseRBrace())
+            return failure();
+        return result;
+    }()
+  }];
+}
+
+/// Attribute that represents an ordered list of the tensor dimensions involved
+/// in a convolution. Entry i names the role of the tensor's i-th dimension,
+/// e.g. `#linalg<conv_dims {N, C, "1", "0"}>`.
+def ConvDimsAttr : AttrDef<Linalg_Dialect, "ConvDims", [], "::mlir::Attribute"> {
+  let mnemonic = "conv_dims";
+
+  let parameters = (ins
+    ConvolutionDimArray:$dims
+  );
+
+  // The textual form is just the dimension list, e.g. `{N, F, "1", "0"}`.
+  let assemblyFormat = "$dims";
+
+  // Op accessors return a `ConvDims` view over the stored dimension array
+  // rather than the attribute itself.
+  let returnType = "mlir::linalg::ConvDims";
+  let convertFromStorage = "mlir::linalg::ConvDims($_self.getDims())";
+}
 #endif // LINALG_BASE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..ef9e00822fbe3b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -63,4 +63,30 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+
+/// An i8-valued `ConvDimEnum` case. `str` (defaulting to `sym`) is the
+/// spelling used in the textual IR, e.g. "N" for BATCH.
+class ConvDimEnumAttrCase<string sym, int val, string str = sym>
+    : IntEnumAttrCaseBase<I8, sym, str, val>;
+
+def ConvDimEnumAttr :
+    IntEnumAttr<I8, "ConvDimEnum", "role of a tensor dimension in a convolution", [
+      /// Batch is a dimension of input and output, indexed from a parallel loop.
+      ConvDimEnumAttrCase<"BATCH", 0, "N">,
+      /// Input channel is a dimension of input and filter, indexed from a
+      /// reduction loop. Depthwise convolutions perform no reduction across
+      /// channels and therefore do not use this.
+      ConvDimEnumAttrCase<"INPUT_CHANNEL", 1, "C">,
+      /// Output channel is a dimension of filter and output, indexed from a
+      /// parallel loop.
+      ConvDimEnumAttrCase<"OUTPUT_CHANNEL", 2, "F">,
+      /// Group is a dimension in all tensors and indexed from a parallel loop.
+      ConvDimEnumAttrCase<"GROUP", 3, "G">,
+      /// Spatial dimensions occur in all tensors. Output is indexed from a
+      /// parallel loop, filter from a reduction loop and input from both
+      /// (`output_index * stride + filter_index * dilation`).
+      ConvDimEnumAttrCase<"SPATIAL_0", 4, "0">,
+      ConvDimEnumAttrCase<"SPATIAL_1", 5, "1">,
+      ConvDimEnumAttrCase<"SPATIAL_2", 6, "2">,
+    ]> {
+  let underlyingType = "uint8_t";
+  let cppNamespace = "::mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6f1c243cc4396d..752fcd8affaa27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -117,6 +117,33 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
 bool isaConvolutionOpInterface(LinalgOp linalgOp,
                                bool allowEmptyConvolvedDims = false);
 
+enum class ConvDimEnum : uint8_t;
+
+/// Non-owning view over an ordered list of convolution dimensions, e.g. the
+/// contents of a `ConvDimsAttr`. The viewed storage must outlive the view.
+class ConvDims {
+public:
+  ConvDims() = default;
+  ConvDims(ArrayRef<ConvDimEnum> dims) : storage(dims) {}
+  ConvDims(SmallVectorImpl<ConvDimEnum> &dims) : storage(dims) {}
+
+  /// Returns true if `dim` occurs anywhere in the list.
+  bool contains(ConvDimEnum dim) const {
+    return llvm::find(storage, dim) != storage.end();
+  }
+
+  /// Returns the index of `dim` in the list; `dim` must be present.
+  int64_t getPos(ConvDimEnum dim) const {
+    const ConvDimEnum *pos = llvm::find(storage, dim);
+    assert(pos != storage.end() && "expected dimension to be present");
+    return pos - storage.begin();
+  }
+
+  /// Number of dimensions in the list.
+  int64_t size() const { return storage.size(); }
+
+  /// Iteration and implicit conversion expose the underlying list.
+  auto begin() const { return storage.begin(); }
+  auto end() const { return storage.end(); }
+  operator ArrayRef<ConvDimEnum>() const { return storage; }
+
+private:
+  ArrayRef<ConvDimEnum> storage;
+};
+
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37eec6e07963b1..09b2dfd75cf67e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -683,6 +683,122 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ConvOp
+//===----------------------------------------------------------------------===//
+
+def ConvOp : LinalgStructuredBase_Op<"conv", [AttrSizedOperandSegments]> {
+
+  let summary = [{
+    Convolution operation with configurable tensor layouts.
+  }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    The subtype of convolution is defined by the tensor layouts of `input`,
+    `filter`, and `output`. For example, a standard batched 2D convolution:
+
+    ```
+      %0 = linalg.conv {
+          input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+          filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+          output_dims = #linalg<conv_dims {N, F, "1", "0"}>
+        }
+        ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+        outs(%output : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+    ```
+
+    The same op becomes a depthwise convolution by replacing the channel
+    dimensions with a group dimension:
+    ```
+      %0 = linalg.conv {
+          input_dims = #linalg<conv_dims {N, G, "1", "0"}>,
+          filter_dims = #linalg<conv_dims {G, "1", "0"}>,
+          output_dims = #linalg<conv_dims {N, G, "1", "0"}>
+        }
+        ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+        outs(%output : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+    ```
+
+    For the detailed semantics of the available tensor dimensions, refer to
+    `mlir::linalg::ConvDimEnum`.
+
+    Strides and dilations can be supplied as optional attributes, where
+    `strides[0]` is the stride for the `SPATIAL_0` dimension, etc. When
+    supplied, each must provide one value per spatial dimension. When
+    omitted, a stride/dilation of 1 is used for every spatial dimension.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs, Variadic<AnyShaped>:$outputs,
+    ConvDimsAttr:$input_dims, ConvDimsAttr:$filter_dims, ConvDimsAttr:$output_dims,
+    OptionalAttr<I64ElementsAttr>:$strides, OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "TypeRange":$resTys, "Value":$input, "Value":$filter, "Value":$output, "ConvDims":$input_dims,
+            "ConvDims":$filter_dims, "ConvDims":$output_dims, "ArrayRef<int64_t>":$strides,
+            "ArrayRef<int64_t>":$dilations, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resTys, input, filter, output,
+            input_dims, filter_dims, output_dims, strides, dilations,
+            attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+            "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, std::nullopt, inputs, outputs,
+            input_dims, filter_dims, output_dims, nullptr, nullptr,
+            attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+            "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resultTensorTypes,
+            inputs, outputs, input_dims, filter_dims, output_dims, nullptr, nullptr,
+            attributes, ConvOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// NOTE(review): declaration copied from MatmulOp; no definition is
+    /// visible in this patch. The indexing maps of this op are derived from
+    /// its `*_dims`, `strides` and `dilations` attributes by
+    /// `getIndexingMaps()` — confirm this declaration is still needed.
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() { return regionBuilder; }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    bool hasDynamicIndexingMaps() { return true; }
+
+    /// Returns the number of spatial dimensions, i.e. 1 for 1D convolution,
+    /// 2 for 2D convolution, etc.
+    int64_t getNumSpatialDims();
+
+    bool isDepthwise();
+    bool isGrouped();
+    bool isBatched();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..03d9a7f3f09ce3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,41 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes,
+                        ValueRange inputs, ValueRange outputs,
+                        ConvDimsAttr inputDims, ConvDimsAttr filterDims,
+                        ConvDimsAttr outputDims, Attribute strides,
+                        Attribute dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  // The three layout attributes are always attached.
+  state.addAttribute("input_dims", inputDims);
+  state.addAttribute("filter_dims", filterDims);
+  state.addAttribute("output_dims", outputDims);
+
+  // `strides` / `dilations` are optional: a null attribute is not attached.
+  auto addIfPresent = [&state](StringRef name, Attribute attr) {
+    if (attr)
+      state.addAttribute(name, attr);
+  };
+  addIfPresent("strides", strides);
+  addIfPresent("dilations", dilations);
+
+  buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes,
+                    regionBuilder);
+}
+
+/// Convenience overload: wraps the dimension lists into `ConvDimsAttr`s and
+/// the stride/dilation lists into i64 vector attributes. An empty `strides`
+/// or `dilations` list means "use the default of 1 per spatial dimension";
+/// the attribute is then omitted entirely. This avoids attaching an empty
+/// dense attribute that `getIndexingMaps` would index out of bounds.
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes, Value input,
+                        Value filter, Value output, ConvDims inputDims,
+                        ConvDims filterDims, ConvDims outputDims,
+                        ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  auto iAttr = ConvDimsAttr::get(b.getContext(), inputDims);
+  auto fAttr = ConvDimsAttr::get(b.getContext(), filterDims);
+  auto oAttr = ConvDimsAttr::get(b.getContext(), outputDims);
+
+  Attribute stridesAttr;
+  if (!strides.empty())
+    stridesAttr = b.getI64VectorAttr(strides);
+  Attribute dilationsAttr;
+  if (!dilations.empty())
+    dilationsAttr = b.getI64VectorAttr(dilations);
+
+  buildConvOp(b, state, resultTensorTypes, {input, filter}, {output}, iAttr,
+              fAttr, oAttr, stridesAttr, dilationsAttr, attributes,
+              regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3611,5 +3646,216 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+bool ConvOp::isDepthwise() {
+  // A filter without an input-channel dimension performs no reduction across
+  // channels, which is what makes the convolution depthwise.
+  ConvDims filterDims = getFilterDims();
+  return !filterDims.contains(ConvDimEnum::INPUT_CHANNEL);
+}
+
+bool ConvOp::isGrouped() {
+  // Only a GROUP dimension shared by input, filter and output counts. A
+  // tensor without GROUP means the op is either ungrouped or has a single
+  // group, and neither case is considered grouped.
+  if (!getInputDims().contains(ConvDimEnum::GROUP))
+    return false;
+  if (!getFilterDims().contains(ConvDimEnum::GROUP))
+    return false;
+  return getOutputDims().contains(ConvDimEnum::GROUP);
+}
+
+bool ConvOp::isBatched() {
+  // Batched only if BATCH appears in the input and in the output.
+  if (!getInputDims().contains(ConvDimEnum::BATCH))
+    return false;
+  return getOutputDims().contains(ConvDimEnum::BATCH);
+}
+
+int64_t ConvOp::getNumSpatialDims() {
+  // The highest spatial dimension present in the input determines the count.
+  ConvDims inputDims = getInputDims();
+  return inputDims.contains(ConvDimEnum::SPATIAL_2)   ? 3
+         : inputDims.contains(ConvDimEnum::SPATIAL_1) ? 2
+                                                      : 1;
+}
+
+SmallVector<utils::IteratorType> ConvOp::getIteratorTypesArray() {
+  // Parallel loops first: one per output dimension.
+  SmallVector<utils::IteratorType> iteratorTypes(getOutputDims().size(),
+                                                 utils::IteratorType::parallel);
+  // Then reduction loops: one per spatial dimension, plus one for the input
+  // channel unless the convolution is depthwise.
+  int64_t numReductionDims = getNumSpatialDims() + (isDepthwise() ? 0 : 1);
+  iteratorTypes.append(numReductionDims, utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+/// Builds the input, filter and output indexing maps, in that order. They are
+/// memoized on the op under `LinalgDialect::kMemoizedIndexingMapsAttrName`.
+///
+/// Loop numbering: first one parallel loop per output dimension (in output
+/// order), then one reduction loop per INPUT_CHANNEL / spatial dimension of
+/// the filter (in filter order). Spatial input dimensions are indexed as
+/// `parallel * stride + reduction * dilation`.
+///
+/// Assumes a verified op. `strides`/`dilations`, when present, must provide
+/// at least `getNumSpatialDims()` values. Every dimension used by the input
+/// or filter must have a matching loop. Otherwise the lookups below read out
+/// of bounds or yield null (default-constructed) AffineExprs.
+ArrayAttr ConvOp::getIndexingMaps() {
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  Builder b(getContext());
+  // Affine constants for strides/dilations; absent attributes mean 1 for
+  // every spatial dimension.
+  SmallVector<AffineExpr> strides, dilations;
+  {
+    SmallVector<int64_t> strideValues, dilationValues;
+
+    if (getStrides())
+      strideValues = SmallVector<int64_t>(getStrides()->getValues<int64_t>());
+    else
+      strideValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    if (getDilations())
+      dilationValues =
+          SmallVector<int64_t>(getDilations()->getValues<int64_t>());
+    else
+      dilationValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    for (int j = 0; j < getNumSpatialDims(); ++j) {
+      strides.push_back(b.getAffineConstantExpr(strideValues[j]));
+      dilations.push_back(b.getAffineConstantExpr(dilationValues[j]));
+    }
+  }
+
+  // Loop expression assigned to each dimension kind, split by loop kind.
+  llvm::DenseMap<ConvDimEnum, AffineExpr> parallelDims;
+  llvm::DenseMap<ConvDimEnum, AffineExpr> reductionDims;
+  SmallVector<AffineExpr> oExprs;
+
+  // Via the iterator types, we have defined the parallel loops to come first,
+  // followed by the reduction loops. We choose the order of the parallel loops
+  // to match the order of the output tensor dimensions. This is arbitrary and
+  // is done to follow the convention which most/some of the old linalg
+  // convolution ops follow.
+  int64_t i = 0;
+  for (auto d : getOutputDims()) {
+    auto expr = b.getAffineDimExpr(i++);
+    parallelDims[d] = expr;
+    oExprs.push_back(expr);
+  }
+  // Reduction loops are ordered to match the order of the filter tensor.
+  for (auto d : getFilterDims())
+    if (d == ConvDimEnum::INPUT_CHANNEL || d == ConvDimEnum::SPATIAL_0 ||
+        d == ConvDimEnum::SPATIAL_1 || d == ConvDimEnum::SPATIAL_2)
+      reductionDims[d] = b.getAffineDimExpr(i++);
+
+  // Input: a spatial dim combines the output position (scaled by stride) with
+  // the window offset (scaled by dilation). The input channel uses its
+  // reduction loop, and every other dim uses its parallel loop.
+  SmallVector<AffineExpr> iExprs =
+      llvm::map_to_vector(getInputDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        switch (dim) {
+        case ConvDimEnum::SPATIAL_0:
+          return (parallelDims[dim] * strides[0]) +
+                 (reductionDims[dim] * dilations[0]);
+        case ConvDimEnum::SPATIAL_1:
+          return (parallelDims[dim] * strides[1]) +
+                 (reductionDims[dim] * dilations[1]);
+        case ConvDimEnum::SPATIAL_2:
+          return (parallelDims[dim] * strides[2]) +
+                 (reductionDims[dim] * dilations[2]);
+        case ConvDimEnum::INPUT_CHANNEL:
+          return reductionDims[dim];
+        default:
+          return parallelDims[dim];
+        }
+      });
+  // Filter: use the reduction loop where one was assigned above, otherwise
+  // the parallel loop of the same dimension.
+  SmallVector<AffineExpr> fExprs =
+      llvm::map_to_vector(getFilterDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        if (reductionDims.contains(dim))
+          return reductionDims[dim];
+        return parallelDims[dim];
+      });
+
+  cached = b.getAffineMapArrayAttr(
+      {AffineMap::get(getNumLoops(), 0, iExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, fExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, oExprs, getContext())});
+  getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
+void ConvOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                           ArrayRef<NamedAttribute> attrs) {
+  RegionBuilderHelper helper(b, block);
+
+  // Signed casts by default. The first attribute named `cast` overrides this
+  // when it holds a TypeFnAttr.
+  TypeFn castVal = TypeFn::cast_signed;
+  for (const NamedAttribute &attr : attrs) {
+    if (attr.getName() == "cast") {
+      if (auto typeFnAttr = llvm::dyn_cast<TypeFnAttr>(attr.getValue()))
+        castVal = typeFnAttr.getValue();
+      break;
+    }
+  }
+
+  // acc + cast(input) * cast(filter), with both operands cast to the
+  // accumulator type.
+  Value acc = block.getArgument(2);
+  Value lhs =
+      helper.buildTypeFn(castVal, acc.getType(), block.getArgument(0));
+  Value rhs =
+      helper.buildTypeFn(castVal, acc.getType(), block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhs, rhs);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+
+  SmallVector<Value> yields{sum};
+  helper.yieldOutputs(yields);
+}
+
+ParseResult ConvOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Input, filter and output each contribute one region block argument.
+  constexpr unsigned kNumRegionArgs = 3;
+  return ::parseNamedStructuredOp(parser, result, kNumRegionArgs,
+                                  getRegionBuilder());
+}
+void ConvOp::print(OpAsmPrinter &p) {
+  // Hide the operand segment sizes and the memoized indexing-map cache.
+  SmallVector<StringRef, 3> elidedAttrs;
+  elidedAttrs.push_back("operandSegmentSizes");
+  elidedAttrs.push_back("linalg.memoized_indexing_maps");
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+}
+
+/// Verifies the dimension layouts and the stride/dilation attributes. The
+/// checks cover the preconditions `getIndexingMaps` relies on:
+/// - every dimension referenced by the input or filter has a loop;
+/// - no dimension occurs twice in one tensor;
+/// - strides/dilations have one value per spatial dimension.
+LogicalResult ConvOp::verify() {
+  // The region builder (3 block arguments) and getIndexingMaps (3 maps)
+  // assume exactly an input, a filter and an output.
+  if (getInputs().size() != 2 || getOutputs().size() != 1)
+    return emitOpError("Expected exactly two inputs and one output.");
+
+  ConvDims inputDims = getInputDims();
+  ConvDims filterDims = getFilterDims();
+  ConvDims outputDims = getOutputDims();
+
+  // A dimension may occur at most once per tensor. getIndexingMaps keys its
+  // loop tables by dimension kind, so duplicates would silently alias.
+  auto hasDuplicates = [](ConvDims dims) {
+    uint32_t seen = 0;
+    for (ConvDimEnum dim : dims) {
+      uint32_t bit = 1u << static_cast<unsigned>(dim);
+      if (seen & bit)
+        return true;
+      seen |= bit;
+    }
+    return false;
+  };
+  if (hasDuplicates(inputDims))
+    return emitOpError("Duplicate dimension in `input_dims`.");
+  if (hasDuplicates(filterDims))
+    return emitOpError("Duplicate dimension in `filter_dims`.");
+  if (hasDuplicates(outputDims))
+    return emitOpError("Duplicate dimension in `output_dims`.");
+
+  // Batch dimension cannot be present in filter tensor.
+  if (filterDims.contains(ConvDimEnum::BATCH))
+    return emitOpError("Batch dimension cannot be present in filter tensor.");
+
+  // Output channel cannot be present in input tensor.
+  if (inputDims.contains(ConvDimEnum::OUTPUT_CHANNEL))
+    return emitOpError("Output channel cannot be present in input tensor.");
+
+  // The input channel is reduced over, so it never indexes the output. It
+  // can index the input only if the filter provides its reduction loop.
+  if (outputDims.contains(ConvDimEnum::INPUT_CHANNEL))
+    return emitOpError("Input channel cannot be present in output tensor.");
+  if (inputDims.contains(ConvDimEnum::INPUT_CHANNEL) &&
+      !filterDims.contains(ConvDimEnum::INPUT_CHANNEL))
+    return emitOpError("Input channel in input tensor requires input channel "
+                       "in filter tensor.");
+
+  // Batch, group and output channel are parallel loops, and those loops are
+  // created from the output dimensions only.
+  for (ConvDimEnum dim : {ConvDimEnum::BATCH, ConvDimEnum::GROUP,
+                          ConvDimEnum::OUTPUT_CHANNEL}) {
+    if ((inputDims.contains(dim) || filterDims.contains(dim)) &&
+        !outputDims.contains(dim))
+      return emitOpError("Parallel dimensions of input and filter tensors "
+                         "must be present in output tensor.");
+  }
+
+  // Higher space dimensions cannot occur without the respective lower ones, so
+  // as to work with the `strides` and `dilations` attributes.
+  bool isSpat2 = inputDims.contains(ConvDimEnum::SPATIAL_2);
+  bool isSpat1 = inputDims.contains(ConvDimEnum::SPATIAL_1);
+  bool isSpat0 = inputDims.contains(ConvDimEnum::SPATIAL_0);
+
+  if ((isSpat2 && (!isSpat1 || !isSpat0)) || (isSpat1 && !isSpat0))
+    return emitOpError("Inconsistent spatial dimensions in `input_dims`.");
+
+  if (!isSpat0)
+    return emitOpError("Requires at least one spatial dimension.");
+
+  // Spatial dimensions have to match between all tensors.
+  if (isSpat2 != filterDims.contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat2 != outputDims.contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat1 != filterDims.contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat1 != outputDims.contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat0 != filterDims.contains(ConvDimEnum::SPATIAL_0) ||
+      isSpat0 != outputDims.contains(ConvDimEnum::SPATIAL_0))
+    return emitOpError("Inconsistent spatial dimensions between tensors.");
+
+  // `strides[i]` / `dilations[i]` apply to SPATIAL_i and are indexed for
+  // every spatial dimension, so each needs exactly one entry per dimension.
+  int64_t numSpatialDims = getNumSpatialDims();
+  if (getStrides() && getStrides()->getNumElements() != numSpatialDims)
+    return emitOpError(
+        "Expected `strides` to have one entry per spatial dimension.");
+  if (getDilations() && getDilations()->getNumElements() != numSpatialDims)
+    return emitOpError(
+        "Expected `dilations` to have one entry per spatial dimension.");
+
+  return success();
+}
+
+LogicalResult ConvOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  // The only folding performed is delegated to memref::foldMemRefCast.
+  Operation *op = getOperation();
+  return memref::foldMemRefCast(op);
+}
+
+void ConvOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Pure tensor semantics report no effects; otherwise use the generic
+  // structured-op implementation.
+  if (!hasPureTensorSemantics())
+    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ConvOp::getSpeculatability() {
+  // Same generic implementation as MatmulOp::getSpeculatability.
+  auto linalgOp = cast<LinalgOp>(getOperation());
+  return getGenericSpeculatabilityImpl(linalgOp);
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-new-conv.mlir b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
new file mode 100644
index 00000000000000..676c69e2d1a305
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
@@ -0,0 +1,656 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-new-conv -linalg-generalize-named-ops | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @conv_1d_ncw_fcw(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_1d_ncw_fcw(%in: tensor<?x?x?xf32>, %fil: tensor<?x?x?xf32>, %acc: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.conv_1d_ncw_fcw
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%in, %fil : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%acc : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @conv_1d_nwc_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_1d_nwc_wcf(%in: tensor<?x?x?xf32>, %fil: tensor<?x?x?xf32>, %acc: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.conv_1d_nwc_wcf
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%in, %fil : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%acc : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_2d_ngchw_fgchw_dilated_strided(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw_dilated_strided(%in: tensor<?x?x?x?x?xf32>, %fil: tensor<?x?x?x?x?xf32>, %acc: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %res = linalg.conv_2d_ngchw_fgchw
+      {strides = dense<2> : tensor<2xi64>, dilations = dense<3> : tensor<2xi64>}
+      ins(%in, %fil : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+      outs(%acc : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %res : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @conv_1d_nwc_wcf_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK:     linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %0 = arith.mulf %in, %in_0 : f32
+// CHECK:       %1 = arith.addf %out, %0 : f32
+// CHECK:       linalg.yield %1 : f32
+// CHECK:     }
+// CHECK:     return
+// CHECK:   }
+// CHECK: }
+func.func @conv_1d_nwc_wcf_memref(%in: memref<?x?x?xf32>, %fil: memref<?x?x?xf32>, %out: memref<?x?x?xf32>) {
+  linalg.conv_1d_nwc_wcf
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%in, %fil : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%out : memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1)>
+// CHECK: #map2 = affine_map<(d0, d1) -> (d0)>
+// CHECK: module {
+// CHECK:   func.func @conv1d_8_tensor(%arg0: tensor<11xf32>, %arg1: tensor<4xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<11xf32>, tensor<4xf32>) outs(%arg2 : tensor<8xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<8xf32>
+// CHECK:     return %0 : tensor<8xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv1d_8_tensor(%in: tensor<11xf32>, %fil: tensor<4xf32>, %acc: tensor<8xf32>) -> tensor<8xf32> {
+  %res = linalg.conv_1d
+      ins(%in, %fil : tensor<11xf32>, tensor<4xf32>)
+      outs(%acc : tensor<8xf32>) -> tensor<8xf32>
+  return %res : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%arg2 : tensor<8x16x14x14xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<8x16x14x14xf32>
+// CHECK:     return %0 : tensor<8x16x14x14xf32>
+// CHECK:   }
+// CHECK: }
+func.func @batch_nchw_conv(%input: tensor<8x4x16x16xf32>, %filter: tensor<16x4x3x3xf32>, %init: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %res = linalg.conv_2d_nchw_fchw
+      {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+      outs(%init : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %res : tensor<8x16x14x14xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_2d_ngchw_fgchw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // NGCHW input with an FGCHW filter; every dim is dynamic, strides/dilations are 1.
+  %grouped = linalg.conv_2d_ngchw_fgchw {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+      outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %grouped : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_2d_ngchw_gfchw(%arg0: tensor<1x5x3x32x32xf32>, %arg1: tensor<5x2x3x3x3xf32>, %arg2: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>) outs(%arg2 : tensor<1x5x2x30x30xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<1x5x2x30x30xf32>
+// CHECK:     return %0 : tensor<1x5x2x30x30xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+  // Static shapes: G=5, C=3, F=2 and a 3x3 window, so 32x32 spatial shrinks to 30x30.
+  %res = linalg.conv_2d_ngchw_gfchw
+           {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+           ins(%input, %filter : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+           outs(%init : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  return %res : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @conv_2d_nhwc_fhwc(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // NHWC input with an FHWC filter; all dims dynamic.
+  %res = linalg.conv_2d_nhwc_fhwc
+           {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+           ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+           outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<1x14x14x16xf32>
+// CHECK:     return %0 : tensor<1x14x14x16xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  // NHWC input with an HWCF filter; unit strides and dilations.
+  %out = linalg.conv_2d_nhwc_hwcf {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+      ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+      outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %out : tensor<1x14x14x16xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_2d_nhwgc_gfhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+// CHECK:     linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %0 = arith.mulf %in, %in_0 : f32
+// CHECK:       %1 = arith.addf %out, %0 : f32
+// CHECK:       linalg.yield %1 : f32
+// CHECK:     }
+// CHECK:     return
+// CHECK:   }
+// CHECK: }
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // Buffer (memref) form: the op writes into %output and produces no SSA result.
+  linalg.conv_2d_nhwgc_gfhwc
+      {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+      outs(%output : memref<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: module {
+// CHECK:   func.func @conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+// CHECK:     linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %0 = arith.mulf %in, %in_0 : f32
+// CHECK:       %1 = arith.addf %out, %0 : f32
+// CHECK:       linalg.yield %1 : f32
+// CHECK:     }
+// CHECK:     return
+// CHECK:   }
+// CHECK: }
+func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) {
+  linalg.conv_2d ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)  // rank-2 operands, no stride/dilation attrs
+  return  // output is written in place into %arg2
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_3d_ncdhw_fcdhw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // NCDHW input, FCDHW filter; three spatial dims, hence 3-element attributes.
+  %res = linalg.conv_3d_ncdhw_fcdhw
+           {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
+           ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+           outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %res : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK:     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %1 = arith.mulf %in, %in_0 : f32
+// CHECK:       %2 = arith.addf %out, %1 : f32
+// CHECK:       linalg.yield %2 : f32
+// CHECK:     } -> tensor<?x?x?x?x?xf32>
+// CHECK:     return %0 : tensor<?x?x?x?x?xf32>
+// CHECK:   }
+// CHECK: }
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // NDHWC input, DHWCF filter; three spatial dims, hence 3-element attributes.
+  %res = linalg.conv_3d_ndhwc_dhwcf
+           {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
+           ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+           outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %res : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK:     linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %0 = arith.mulf %in, %in_0 : f32
+// CHECK:       %1 = arith.addf %out, %0 : f32
+// CHECK:       linalg.yield %1 : f32
+// CHECK:     }
+// CHECK:     return
+// CHECK:   }
+// CHECK: }
+func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // Rank-3 operands, no stride/dilation attributes; the result is written into %arg2.
+  linalg.conv_3d ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_1d_ncw_cw(%arg0: tensor<1x8x12xf32>, %arg1: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<1x8x10xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x8x10xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<1x8x10xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x8x12xf32>, tensor<8x3xf32>) outs(%1 : tensor<1x8x10xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<1x8x10xf32>
+// CHECK:     return %2 : tensor<1x8x10xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_1d_ncw_cw(%input: tensor<1x8x12xf32>, %filter: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<1x8x10xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+  // The zero-filled %acc is the accumulator of the depthwise conv below.
+  %dw = linalg.depthwise_conv_1d_ncw_cw
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%input, %filter : tensor<1x8x12xf32>, tensor<8x3xf32>) outs(%acc : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+  return %dw : tensor<1x8x10xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_1d_nwc_wc(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<1x10x8xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<1x10x8xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8xf32>) outs(%1 : tensor<1x10x8xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<1x10x8xf32>
+// CHECK:     return %2 : tensor<1x10x8xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<1x10x8xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+  // Channels-last variant: width-3 window over W=12 yields W=10, C=8 unchanged.
+  %dw = linalg.depthwise_conv_1d_nwc_wc
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8xf32>) outs(%acc : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+  return %dw : tensor<1x10x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_1d_nwc_wcm(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<1x10x8x8xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8x8xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<1x10x8x8xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8x8xf32>) outs(%1 : tensor<1x10x8x8xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<1x10x8x8xf32>
+// CHECK:     return %2 : tensor<1x10x8x8xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<1x10x8x8xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+  // The trailing M dim of the filter (size 8) becomes an extra trailing output dim.
+  %dw = linalg.depthwise_conv_1d_nwc_wcm
+      {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
+      ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8x8xf32>) outs(%acc : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+  return %dw : tensor<1x10x8x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_2d_nchw_chw_tensor(%arg0: tensor<1x96x113x113xf32>, %arg1: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+// CHECK:     %0 = tensor.empty() : tensor<1x96x56x56xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>) outs(%0 : tensor<1x96x56x56xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %2 = arith.mulf %in, %in_0 : f32
+// CHECK:       %3 = arith.addf %out, %2 : f32
+// CHECK:       linalg.yield %3 : f32
+// CHECK:     } -> tensor<1x96x56x56xf32>
+// CHECK:     return %1 : tensor<1x96x56x56xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+  %out_init = tensor.empty() : tensor<1x96x56x56xf32>
+  // Stride 2 in both spatial dims; note the vector-typed (not tensor-typed) attributes.
+  %dw = linalg.depthwise_conv_2d_nchw_chw {strides = dense<2> : vector<2xi64>, dilations = dense<1> : vector<2xi64>}
+      ins(%input, %filter : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
+      outs(%out_init : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  return %dw : tensor<1x96x56x56xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK:   func.func @convolution_depthwise(%arg0: tensor<1x10x196x48xf32>, %arg1: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<1x10x191x48xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x191x48xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<1x10x191x48xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>) outs(%1 : tensor<1x10x191x48xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<1x10x191x48xf32>
+// CHECK:     return %2 : tensor<1x10x191x48xf32>
+// CHECK:   }
+// CHECK: }
+func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<1x10x191x48xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+  // NOTE(review): a width-4 window over W=196 gives 193 full positions; this test only produces 191 of them.
+  %result = linalg.depthwise_conv_2d_nhwc_hwc {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+    outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+  return %result : tensor<1x10x191x48xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_2d_nhwc_hwcm(%arg0: memref<2x4x5x2xf32>, %arg1: memref<2x2x2x3xf32>, %arg2: memref<2x3x4x2x3xf32>) {
+// CHECK:     linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>) outs(%arg2 : memref<2x3x4x2x3xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %0 = arith.mulf %in, %in_0 : f32
+// CHECK:       %1 = arith.addf %out, %0 : f32
+// CHECK:       linalg.yield %1 : f32
+// CHECK:     }
+// CHECK:     return
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+  // The filter's trailing M dim (size 3) shows up as the output's trailing dim.
+  linalg.depthwise_conv_2d_nhwc_hwcm {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+      outs(%output : memref<2x3x4x2x3xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5, d3 + d6, d4 * 3 + d7)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d5, d6, d7)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_3d_ncdhw_cdhw(%arg0: tensor<2x6x6x13x12xf32>, %arg1: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<2x6x3x13x4xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x6x3x13x4xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<2x6x3x13x4xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>) outs(%1 : tensor<2x6x3x13x4xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<2x6x3x13x4xf32>
+// CHECK:     return %2 : tensor<2x6x3x13x4xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<2x6x6x13x12xf32>, %filter: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<2x6x3x13x4xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+  // Non-uniform strides [2, 1, 3] exercise per-dimension stride handling.
+  %dw = linalg.depthwise_conv_3d_ncdhw_cdhw
+      {strides = dense<[2, 1, 3]> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
+      ins(%input, %filter : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>) outs(%acc : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+  return %dw : tensor<2x6x3x13x4xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 * 2 + d5, d2 + d6, d3 * 3 + d7, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_3d_ndhwc_dhwc(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<2x3x13x4x6xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<2x3x13x4x6xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>) outs(%1 : tensor<2x3x13x4x6xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<2x3x13x4x6xf32>
+// CHECK:     return %2 : tensor<2x3x13x4x6xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<2x3x13x4x6xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+  // Channels-last counterpart of the NCDHW case, same [2, 1, 3] strides.
+  %dw = linalg.depthwise_conv_3d_ndhwc_dhwc
+      {strides = dense<[2, 1, 3]> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
+      ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>) outs(%acc : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+  return %dw : tensor<2x3x13x4x6xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 * 2 + d6, d2 + d7, d3 * 3 + d8, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8, d4, d5)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: module {
+// CHECK:   func.func @depthwise_conv_3d_ndhwc_dhwcm(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+// CHECK:     %cst = arith.constant 0.000000e+00 : f32
+// CHECK:     %0 = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+// CHECK:     %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK:     ^bb0(%in: f32, %out: f32):
+// CHECK:       linalg.yield %in : f32
+// CHECK:     } -> tensor<2x3x13x4x6x6xf32>
+// CHECK:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>) outs(%1 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK:     ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK:       %3 = arith.mulf %in, %in_0 : f32
+// CHECK:       %4 = arith.addf %out, %3 : f32
+// CHECK:       linalg.yield %4 : f32
+// CHECK:     } -> tensor<2x3x13x4x6x6xf32>
+// CHECK:     return %2 : tensor<2x3x13x4x6x6xf32>
+// CHECK:   }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+  %c0f = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+  %acc = linalg.fill ins(%c0f : f32) outs(%empty : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+  // Adds a trailing M dim (size 6) on top of the dhwc case; strides are [2, 1, 3].
+  %dw = linalg.depthwise_conv_3d_ndhwc_dhwcm
+      {strides = dense<[2, 1, 3]> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
+      ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>) outs(%acc : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+  return %dw : tensor<2x3x13x4x6x6xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..5a8f1dbf84db05 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -694,3 +694,60 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+// Round-trip of `linalg.conv` with one spatial dimension and fully dynamic shapes.
+func.func @newconv_1d(%in: tensor<?x?x?xf32>, %flt: tensor<?x?x?xf32>, %out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.conv {
+      input_dims = #linalg<conv_dims {N, C, "0"}>,
+      filter_dims = #linalg<conv_dims {F, C, "0"}>,
+      output_dims = #linalg<conv_dims {N, F, "0"}>
+    }
+    ins(%in, %flt : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func @newconv_1d(
+// CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<?x?x?xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<?x?x?xf32>, %[[arg2:[a-zA-z0-9]*]]: tensor<?x?x?xf32>
+// CHECK:   linalg.conv {filter_dims = #linalg<conv_dims {F, C, "0"}>, input_dims = #linalg<conv_dims {N, C, "0"}>, output_dims = #linalg<conv_dims {N, F, "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[arg2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+// Round-trip of a grouped (depthwise) 2-D `linalg.conv` with static shapes.
+func.func @newconv_depthwise_2d(%img: tensor<8x4x16x16xf32>, %kernel: tensor<4x3x3xf32>) -> tensor<8x4x14x14xf32> {
+  %out = tensor.empty() : tensor<8x4x14x14xf32>
+
+  %res = linalg.conv {
+      filter_dims = #linalg<conv_dims {G, "1", "0"}>,
+      input_dims = #linalg<conv_dims {N, G, "1", "0"}>,
+      output_dims = #linalg<conv_dims {N, G, "1", "0"}>
+    }
+    ins(%img, %kernel : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+    outs(%out : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+  return %res : tensor<8x4x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_depthwise_2d(
+// CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<4x3x3xf32>
+// CHECK:   %[[arg2:[a-zA-z0-9]*]] = tensor.empty() : tensor<8x4x14x14xf32>
+// CHECK:   linalg.conv {filter_dims = #linalg<conv_dims {G, "1", "0"}>, input_dims = #linalg<conv_dims {N, G, "1", "0"}>, output_dims = #linalg<conv_dims {N, G, "1", "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>) outs(%[[arg2]] : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+// -----
+
+// Round-trip of a regular 2-D `linalg.conv` (input channel C, output channel F).
+func.func @newconv_2d(%input: tensor<8x4x16x16xf32>, %filter: tensor<16x4x3x3xf32>) -> tensor<8x16x14x14xf32> {
+  %init = tensor.empty() : tensor<8x16x14x14xf32>
+  %0 = linalg.conv {
+      input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+      filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+      output_dims = #linalg<conv_dims {N, F, "1", "0"}>
+    }
+    ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+    outs(%init : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_2d(
+// CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<16x4x3x3xf32>
+// CHECK:   %[[arg2:[a-zA-z0-9]*]] = tensor.empty() : tensor<8x16x14x14xf32>
+// CHECK:   linalg.conv {filter_dims = #linalg<conv_dims {F, C, "1", "0"}>, input_dims = #linalg<conv_dims {N, C, "1", "0"}>, output_dims = #linalg<conv_dims {N, F, "1", "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%[[arg2]] : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 283e426b4e5947..37d04821935700 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
+  TestNewConv.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -32,4 +33,4 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRVectorDialect
   MLIRVectorToSCF
   MLIRVectorTransforms
-  )
+)
diff --git a/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
new file mode 100644
index 00000000000000..53564738171dbd
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
@@ -0,0 +1,187 @@
+//===- TestNewConv.cpp - Test `linalg.conv` -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass which converts "old" convolution ops, e.g.
+// `linalg.depthwise_conv_2d_nhwc`, `linalg.conv_2d_nhwc`, etc., to the new
+// `linalg.conv` op.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Rewrites the "old" named convolution ops (`linalg.conv_*` and
+/// `linalg.depthwise_conv_*`) into the new `linalg.conv` op.
+///
+/// The tensor layouts are recovered from the op name: e.g. for
+/// `conv_2d_nhwc_hwcf` the input layout is `nhwc` and the filter layout is
+/// `hwcf`. Ops that cannot be expressed this way (for instance the quantized
+/// `*_q` variants, which take extra zero-point operands) are rejected with
+/// `failure()` instead of being rewritten.
+class OldToNewConv : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+public:
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::isa<linalg::ConvOp>(op))
+      return failure();
+    // `linalg.conv` is built from exactly (input, filter, output). Any other
+    // operand count (e.g. the quantized `*_q` ops, which carry two extra
+    // zero-point operands) has no equivalent and must not be rewritten.
+    if (op->getNumOperands() != 3)
+      return failure();
+
+    // Operation names are uniqued in the context, so this view stays valid.
+    StringRef name = op->getName().stripDialect();
+
+    bool isDepthwise = name.consume_front("depthwise_conv_");
+    if (!isDepthwise && !name.consume_front("conv_"))
+      return failure();
+
+    int64_t spatialDims;
+    if (name.consume_front("1d"))
+      spatialDims = 1;
+    else if (name.consume_front("2d"))
+      spatialDims = 2;
+    else if (name.consume_front("3d"))
+      spatialDims = 3;
+    else
+      return failure();
+
+    SmallVector<ConvDimEnum, 4> inputDims, filterDims, outputDims;
+    if (name.empty()) {
+      // These are the ops `conv_1d`, `conv_2d` and `conv_3d` which use only
+      // spatial dimensions, in the same order for input and filter.
+      const ConvDimEnum spatial[] = {ConvDimEnum::SPATIAL_0,
+                                     ConvDimEnum::SPATIAL_1,
+                                     ConvDimEnum::SPATIAL_2};
+      inputDims.append(spatial, spatial + spatialDims);
+      filterDims = inputDims;
+    } else {
+      // This handles all the ops with specialized tensor dimension orders like
+      // `conv_2d_nhwc_fhwc`, `depthwise_conv_2d_nhwc_hwc`, etc.
+      if (!name.consume_front("_"))
+        return failure();
+      // Separator between input and filter layout. Any further suffix (such as
+      // `_q`) ends up in the filter layout and is rejected while parsing it.
+      auto [inputDimStr, filterDimStr] = name.split('_');
+      if (failed(parseLayout(inputDimStr, isDepthwise, inputDims)) ||
+          failed(parseLayout(filterDimStr, isDepthwise, filterDims)))
+        return failure();
+    }
+
+    // This is the behavior of the old convolution ops:
+    // The output dimension order is the same as the input dimension order, but
+    // output channel stands in for input channel...
+    for (ConvDimEnum d : inputDims)
+      outputDims.push_back(d == ConvDimEnum::INPUT_CHANNEL
+                               ? ConvDimEnum::OUTPUT_CHANNEL
+                               : d);
+    // ... and if the "depthwise channel multiplier" dimension 'm' appears, the
+    // output tensor has an additional dimension appended.
+    if (isDepthwise &&
+        llvm::is_contained(filterDims, ConvDimEnum::OUTPUT_CHANNEL))
+      outputDims.push_back(ConvDimEnum::OUTPUT_CHANNEL);
+
+    // The old convolution ops order the strides and dilations in the order "D,
+    // H, W". We order them as spatial 0, spatial 1, spatial 2, so we have to
+    // reverse the order. Absent attributes default to all ones.
+    SmallVector<int64_t> strides(spatialDims, 1), dilations(spatialDims, 1);
+    if (failed(readReversed(op, "strides", strides)) ||
+        failed(readReversed(op, "dilations", dilations)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
+        op, op->getResultTypes(), op->getOperand(0), op->getOperand(1),
+        op->getOperand(2), inputDims, filterDims, outputDims, strides,
+        dilations);
+
+    return success();
+  }
+
+private:
+  /// Maps a single layout letter of an old conv op name to its dimension
+  /// kind, or returns failure for letters that have no meaning here.
+  static FailureOr<ConvDimEnum> parseDim(char c, bool isDepthwise) {
+    switch (c) {
+    case 'n':
+      return ConvDimEnum::BATCH;
+    case 'h':
+      return ConvDimEnum::SPATIAL_1;
+    case 'w':
+      return ConvDimEnum::SPATIAL_0;
+    case 'd':
+      return ConvDimEnum::SPATIAL_2;
+    case 'f':
+      return ConvDimEnum::OUTPUT_CHANNEL;
+    case 'g':
+      return ConvDimEnum::GROUP;
+    case 'c':
+      // The old convolution ops use the letter 'c' to denote a
+      // non-reduction dimension in all tensors in the depthwise case. The
+      // new convolution captures this behavior in the group dimension.
+      return isDepthwise ? ConvDimEnum::GROUP : ConvDimEnum::INPUT_CHANNEL;
+    case 'm':
+      // Similarly, the old convolution ops use the letter 'm' to denote a
+      // parallel dimension in filter and output in the depthwise case. This
+      // behavior is captured by the ordinary output channel dimension. The
+      // letter is not expected outside of depthwise convolutions.
+      if (!isDepthwise)
+        return failure();
+      return ConvDimEnum::OUTPUT_CHANNEL;
+    default:
+      return failure();
+    }
+  }
+
+  /// Parses a layout string such as `nhwc` into `dims`. Fails on an empty
+  /// layout or on any character that is not a known layout letter.
+  static LogicalResult parseLayout(StringRef layout, bool isDepthwise,
+                                   SmallVectorImpl<ConvDimEnum> &dims) {
+    if (layout.empty())
+      return failure();
+    for (char c : layout) {
+      FailureOr<ConvDimEnum> dim = parseDim(c, isDepthwise);
+      if (failed(dim))
+        return failure();
+      dims.push_back(*dim);
+    }
+    return success();
+  }
+
+  /// Reads the integer elements attribute `attrName` of `op` in reverse order
+  /// into `values`. Leaves `values` untouched if the attribute is absent and
+  /// fails if it is present with a non-integer-elements type.
+  static LogicalResult readReversed(Operation *op, StringRef attrName,
+                                    SmallVectorImpl<int64_t> &values) {
+    Attribute attr = op->getAttr(attrName);
+    if (!attr)
+      return success();
+    auto elements = llvm::dyn_cast<DenseIntElementsAttr>(attr);
+    if (!elements)
+      return failure();
+    SmallVector<int64_t> inOrder(elements.getValues<int64_t>());
+    values.assign(inOrder.rbegin(), inOrder.rend());
+    return success();
+  }
+};
+
+/// Test pass that converts the old linalg convolution ops to `linalg.conv`.
+/// Any old linalg convolution op that the pattern cannot rewrite stays
+/// illegal, so the partial conversion, and therefore the pass, fails.
+struct TestNewConvPass : public PassWrapper<TestNewConvPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNewConvPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-linalg-new-conv"; }
+  StringRef getDescription() const final { return "Test new linalg.conv Op"; }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    target.addLegalOp<linalg::ConvOp>();
+    // Every non-converted old conv op should fail the conversion. Only linalg
+    // ops are considered: ops of other dialects that merely contain "conv" in
+    // their name (e.g. `builtin.unrealized_conversion_cast`) remain legal.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      OperationName opName = op->getName();
+      return opName.getDialectNamespace() !=
+                 linalg::LinalgDialect::getDialectNamespace() ||
+             !opName.stripDialect().contains("conv");
+    });
+
+    patterns.add<OldToNewConv>(context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Makes `TestNewConvPass` available to tools that call this function (it is
+/// invoked from `registerTestPasses` in mlir-opt).
+void registerTestNewConv() { PassRegistration<TestNewConvPass>(); }
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..25a8430500b6cc 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
+void registerTestNewConv();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -248,6 +249,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
+  mlir::test::registerTestNewConv();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();



More information about the Mlir-commits mailing list