[Mlir-commits] [mlir] [mlir][linalg] Introduce new `linalg.conv` op (PR #117688)
Felix Schneider
llvmlistbot at llvm.org
Tue Nov 26 01:26:49 PST 2024
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/117688
This patch lays the groundwork for the new `linalg.conv` op, which is designed to replace the multitude of existing `linalg.conv_...` and `linalg.depthwise_conv_...` ops.
A test pass is implemented which converts the old conv ops to the new op. The `linalg-generalize-named-ops` pass can then be used to lower both the old and the new ops to `linalg.generic` for comparison.
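For example, a standard batched 2D convolution is expressed with the new op as follows (example copied from the op documentation added in this patch):

```
%0 = linalg.conv {
    input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
    filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
    output_dims = #linalg<conv_dims {N, F, "1", "0"}>
  }
  ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
  outs(%output : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
```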
From 4ff4a8a71ef5d118497c269d8f9b60bedd2d4ec8 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Tue, 26 Nov 2024 10:25:47 +0100
Subject: [PATCH] [mlir][linalg] Introduce new `linalg.conv` op
This patch lays the groundwork for the new `linalg.conv` op, which is
designed to replace the multitude of existing `linalg.conv_...` and
`linalg.depthwise_conv_...` ops.
A test pass is implemented which converts the old conv ops to the new
op. The `linalg-generalize-named-ops` pass can then be used to lower both
the old and the new ops to `linalg.generic` for comparison.
---
.../mlir/Dialect/Linalg/IR/LinalgBase.td | 39 ++
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 26 +
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 27 +
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 116 ++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 246 +++++++
.../Dialect/Linalg/generalize-new-conv.mlir | 656 ++++++++++++++++++
mlir/test/Dialect/Linalg/roundtrip.mlir | 57 ++
mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 3 +-
mlir/test/lib/Dialect/Linalg/TestNewConv.cpp | 187 +++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
10 files changed, 1358 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Linalg/generalize-new-conv.mlir
create mode 100644 mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..b659241b5ed5b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -81,4 +81,43 @@ def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
"Iterator type should be an enum.">;
+
+def ConvolutionDimArray : ArrayRefParameter<"ConvDimEnum"> {
+ let printer = [{
+ $_printer << '{';
+ llvm::interleaveComma($_self, $_printer, [&](ConvDimEnum en) {
+ $_printer.printStrippedAttrOrType(en);
+ });
+ $_printer << '}';
+ }];
+
+ let parser = [{
+ [&]() -> FailureOr<SmallVector<ConvDimEnum>> {
+ using Result = SmallVector<ConvDimEnum>;
+ if ($_parser.parseLBrace())
+ return failure();
+ FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+ if (failed(result))
+ return failure();
+ if ($_parser.parseRBrace())
+ return failure();
+ return result;
+ }()
+ }];
+}
+
+/// Attribute that represents an ordered set of tensor dimensions involved in
+/// convolution.
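+/// Printed as, e.g., `#linalg<conv_dims {N, C, "1", "0"}>`.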
+def ConvDimsAttr : AttrDef<Linalg_Dialect, "ConvDims", [], "::mlir::Attribute"> {
+ let mnemonic = "conv_dims";
+
+ let parameters = (ins
+ ConvolutionDimArray:$dims
+ );
+
+ let assemblyFormat = "$dims";
+
+ let returnType = "mlir::linalg::ConvDims";
+ let convertFromStorage = "mlir::linalg::ConvDims($_self.getDims())";
+}
#endif // LINALG_BASE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..ef9e00822fbe3b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -63,4 +63,30 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
let cppNamespace = "::mlir::linalg";
}
+
+class ConvDimEnumAttrCase<string sym, int val, string str = sym>
+ : IntEnumAttrCaseBase<I8, sym, str, val>;
+
+def ConvDimEnumAttr :
+  IntEnumAttr<I8, "ConvDimEnum", "Type of a tensor dimension in a convolution", [
+ /// Batch is a dimension of input and output, indexed from a parallel loop.
+ ConvDimEnumAttrCase<"BATCH", 0, "N">,
+ /// Input channel is a dimension in all tensors, indexed from a reduction loop.
+ /// Depthwise convolutions perform no reduction across channels and therefore
+ /// do not use this.
+ ConvDimEnumAttrCase<"INPUT_CHANNEL", 1, "C">,
+  /// Output channel is a dimension in filter and output, indexed from a parallel loop.
+ ConvDimEnumAttrCase<"OUTPUT_CHANNEL", 2, "F">,
+ /// Group is a dimension in all tensors and indexed from a parallel loop.
+ ConvDimEnumAttrCase<"GROUP", 3, "G">,
+ /// Spatial dimensions occur in all tensors. Output is indexed from a parallel
+ /// loop, filter from a reduction loop and input from both.
+ ConvDimEnumAttrCase<"SPATIAL_0", 4, "0">,
+ ConvDimEnumAttrCase<"SPATIAL_1", 5, "1">,
+ ConvDimEnumAttrCase<"SPATIAL_2", 6, "2">,
+ ]> {
+ let underlyingType = "uint8_t";
+ let cppNamespace = "::mlir::linalg";
+}
+
#endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6f1c243cc4396d..752fcd8affaa27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -117,6 +117,33 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
bool isaConvolutionOpInterface(LinalgOp linalgOp,
bool allowEmptyConvolvedDims = false);
+enum class ConvDimEnum : uint8_t;
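+/// Non-owning, ordered view of the convolution dimensions of one tensor, with
+/// convenience queries for membership and position.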
+class ConvDims {
+ ArrayRef<ConvDimEnum> storage;
+
+public:
+ ConvDims() = default;
+ ConvDims(ArrayRef<ConvDimEnum> dims) : storage(dims) {}
+ ConvDims(SmallVectorImpl<ConvDimEnum> &dims) : storage(dims) {}
+
+ bool contains(ConvDimEnum dim) const {
+ return llvm::is_contained(storage, dim);
+ }
+
+ int64_t getPos(ConvDimEnum dim) const {
+ auto it = llvm::find(storage, dim);
+ assert(it != storage.end() && "expected dimension to be present");
+
+ return std::distance(storage.begin(), it);
+ }
+
+ int64_t size() const { return storage.size(); }
+ operator ArrayRef<ConvDimEnum>() const { return storage; }
+
+ auto begin() const { return storage.begin(); }
+ auto end() const { return storage.end(); }
+};
+
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37eec6e07963b1..09b2dfd75cf67e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -683,6 +683,122 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Op definition for ConvOp
+//===----------------------------------------------------------------------===//
+
+def ConvOp : LinalgStructuredBase_Op<"conv", [AttrSizedOperandSegments]> {
+
+  let summary = [{
+    Convolution operation with configurable tensor layouts.
+  }];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply,
+ promoting them to the same data type as the accumulator/output.
+
+    The kind of convolution (batched, grouped, depthwise, ...) is determined by
+    the tensor layouts of `input`, `filter`, and `output`. For example, a
+    standard batched 2D convolution:
+
+ ```
+ %0 = linalg.conv {
+ input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+ filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+ output_dims = #linalg<conv_dims {N, F, "1", "0"}>
+ }
+ ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+ outs(%output : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+ ```
+
+ This op could be turned into a depthwise convolution as follows:
+ ```
+ %0 = linalg.conv {
+ input_dims = #linalg<conv_dims {N, G, "1", "0"}>,
+ filter_dims = #linalg<conv_dims {G, "1", "0"}>,
+ output_dims = #linalg<conv_dims {N, G, "1", "0"}>
+ }
+ ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+ outs(%output : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+ ```
+
+    For the detailed semantics of the available tensor dimensions, refer to
+    `mlir::linalg::ConvDimEnum`.
+
+ Strides and dilations can be supplied as optional attributes, where
+ `strides[0]` is the stride for the `SPATIAL_0` dimension, etc.
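+
+    As an illustrative sketch (attribute spelling follows the existing
+    `linalg.conv_*` ops; output shape assumes stride 2 and dilation 3 applied
+    to the 2D example above):
+
+    ```
+    %0 = linalg.conv {
+      input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+      filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+      output_dims = #linalg<conv_dims {N, F, "1", "0"}>,
+      strides = dense<2> : tensor<2xi64>,
+      dilations = dense<3> : tensor<2xi64>
+    }
+    ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+    outs(%output : tensor<8x16x5x5xf32>) -> tensor<8x16x5x5xf32>
+    ```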
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs, Variadic<AnyShaped>:$outputs,
+ ConvDimsAttr:$input_dims, ConvDimsAttr:$filter_dims, ConvDimsAttr:$output_dims,
+ OptionalAttr<I64ElementsAttr>:$strides, OptionalAttr<I64ElementsAttr>:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "TypeRange":$resTys, "Value":$input, "Value":$filter, "Value":$output, "ConvDims":$input_dims,
+ "ConvDims":$filter_dims, "ConvDims":$output_dims, "ArrayRef<int64_t>":$strides,
+ "ArrayRef<int64_t>":$dilations, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildConvOp($_builder, $_state, resTys, input, filter, output,
+ input_dims, filter_dims, output_dims, strides, dilations,
+ attributes, ConvOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+ "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildConvOp($_builder, $_state, std::nullopt, inputs, outputs,
+ input_dims, filter_dims, output_dims, nullptr, nullptr,
+ attributes, ConvOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+ "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildConvOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, input_dims, filter_dims, output_dims, nullptr, nullptr,
+ attributes, ConvOp::getRegionBuilder());
+ }]>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ ArrayAttr getIndexingMaps();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMaps with the default indexing characteristics of this op.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() { return regionBuilder; }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+ bool hasDynamicIndexingMaps() { return true; }
+
+ /// Returns the number of spatial dimensions, i.e. 1 for 1D convolution,
+ /// 2 for 2D convolution, etc.
+ int64_t getNumSpatialDims();
+
+ bool isDepthwise();
+ bool isGrouped();
+ bool isBatched();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..03d9a7f3f09ce3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,41 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildConvOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ConvDimsAttr inputDims, ConvDimsAttr filterDims,
+ ConvDimsAttr outputDims, Attribute strides,
+ Attribute dilations,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder) {
+ state.addAttribute("input_dims", inputDims);
+ state.addAttribute("filter_dims", filterDims);
+ state.addAttribute("output_dims", outputDims);
+ if (strides)
+ state.addAttribute("strides", strides);
+
+ if (dilations)
+ state.addAttribute("dilations", dilations);
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
+static void buildConvOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes, Value input,
+ Value filter, Value output, ConvDims inputDims,
+ ConvDims filterDims, ConvDims outputDims,
+ ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder) {
+ auto iAttr = ConvDimsAttr::get(b.getContext(), inputDims);
+ auto fAttr = ConvDimsAttr::get(b.getContext(), filterDims);
+ auto oAttr = ConvDimsAttr::get(b.getContext(), outputDims);
+ return buildConvOp(b, state, resultTensorTypes, {input, filter}, {output},
+ iAttr, fAttr, oAttr, b.getI64VectorAttr(strides),
+ b.getI64VectorAttr(dilations), attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3611,5 +3646,216 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+bool ConvOp::isDepthwise() {
+ return !getFilterDims().contains(ConvDimEnum::INPUT_CHANNEL);
+}
+
+bool ConvOp::isGrouped() {
+ // If not all tensors contain the GROUP dimension, then it's either not a
+ // grouped convolution, or the number of groups is 1, which we also don't
+ // consider grouped.
+ return getInputDims().contains(ConvDimEnum::GROUP) &&
+ getFilterDims().contains(ConvDimEnum::GROUP) &&
+ getOutputDims().contains(ConvDimEnum::GROUP);
+}
+
+bool ConvOp::isBatched() {
+ // Both input and output tensors must contain the BATCH dimension.
+ return getInputDims().contains(ConvDimEnum::BATCH) &&
+ getOutputDims().contains(ConvDimEnum::BATCH);
+}
+
+int64_t ConvOp::getNumSpatialDims() {
+ if (getInputDims().contains(ConvDimEnum::SPATIAL_2))
+ return 3;
+ if (getInputDims().contains(ConvDimEnum::SPATIAL_1))
+ return 2;
+ return 1;
+}
+
+SmallVector<utils::IteratorType> ConvOp::getIteratorTypesArray() {
+ int numParallelDims = getOutputDims().size();
+
+ int numReductionDims = getNumSpatialDims();
+ if (!isDepthwise())
+ ++numReductionDims; // input channel
+
+ SmallVector<utils::IteratorType> iteratorTypes(numParallelDims,
+ utils::IteratorType::parallel);
+ iteratorTypes.append(numReductionDims, utils::IteratorType::reduction);
+ return iteratorTypes;
+}
+
+ArrayAttr ConvOp::getIndexingMaps() {
+ ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ Builder b(getContext());
+ SmallVector<AffineExpr> strides, dilations;
+ {
+ SmallVector<int64_t> strideValues, dilationValues;
+
+ if (getStrides())
+ strideValues = SmallVector<int64_t>(getStrides()->getValues<int64_t>());
+ else
+ strideValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+ if (getDilations())
+ dilationValues =
+ SmallVector<int64_t>(getDilations()->getValues<int64_t>());
+ else
+ dilationValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+ for (int j = 0; j < getNumSpatialDims(); ++j) {
+ strides.push_back(b.getAffineConstantExpr(strideValues[j]));
+ dilations.push_back(b.getAffineConstantExpr(dilationValues[j]));
+ }
+ }
+
+ llvm::DenseMap<ConvDimEnum, AffineExpr> parallelDims;
+ llvm::DenseMap<ConvDimEnum, AffineExpr> reductionDims;
+ SmallVector<AffineExpr> oExprs;
+
+  // Via the iterator types, we have defined the parallel loops to come first,
+  // followed by the reduction loops. We choose the order of the parallel loops
+  // to match the order of the output tensor dimensions. This choice is
+  // arbitrary but follows the convention used by most of the existing linalg
+  // convolution ops.
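+  // For example, a 2D NCHW/FCHW convolution (cf. the generalization tests)
+  // produces the maps:
+  //   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)
+  //   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)
+  //   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+  // where d0..d3 are parallel and d4..d6 are reduction dimensions.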
+ int64_t i = 0;
+ for (auto d : getOutputDims()) {
+ auto expr = b.getAffineDimExpr(i++);
+ parallelDims[d] = expr;
+ oExprs.push_back(expr);
+ }
+ // Reduction loops are ordered to match the order of the filter tensor.
+ for (auto d : getFilterDims())
+ if (d == ConvDimEnum::INPUT_CHANNEL || d == ConvDimEnum::SPATIAL_0 ||
+ d == ConvDimEnum::SPATIAL_1 || d == ConvDimEnum::SPATIAL_2)
+ reductionDims[d] = b.getAffineDimExpr(i++);
+
+ SmallVector<AffineExpr> iExprs =
+ llvm::map_to_vector(getInputDims(), [&](ConvDimEnum dim) -> AffineExpr {
+ switch (dim) {
+ case ConvDimEnum::SPATIAL_0:
+ return (parallelDims[dim] * strides[0]) +
+ (reductionDims[dim] * dilations[0]);
+ case ConvDimEnum::SPATIAL_1:
+ return (parallelDims[dim] * strides[1]) +
+ (reductionDims[dim] * dilations[1]);
+ case ConvDimEnum::SPATIAL_2:
+ return (parallelDims[dim] * strides[2]) +
+ (reductionDims[dim] * dilations[2]);
+ case ConvDimEnum::INPUT_CHANNEL:
+ return reductionDims[dim];
+ default:
+ return parallelDims[dim];
+ }
+ });
+ SmallVector<AffineExpr> fExprs =
+ llvm::map_to_vector(getFilterDims(), [&](ConvDimEnum dim) -> AffineExpr {
+ if (reductionDims.contains(dim))
+ return reductionDims[dim];
+ return parallelDims[dim];
+ });
+
+ cached = b.getAffineMapArrayAttr(
+ {AffineMap::get(getNumLoops(), 0, iExprs, getContext()),
+ AffineMap::get(getNumLoops(), 0, fExprs, getContext()),
+ AffineMap::get(getNumLoops(), 0, oExprs, getContext())});
+ getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+}
+
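+/// Builds the scalar body of the convolution:
+///   out += cast(in0) * cast(in1)
+/// where the cast function defaults to cast_signed and can be overridden via
+/// a "cast" attribute.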
+void ConvOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult ConvOp::parse(OpAsmParser &parser, OperationState &result) {
+ return ::parseNamedStructuredOp(parser, result, 3,
+ ConvOp::getRegionBuilder());
+}
+void ConvOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes",
+ "linalg.memoized_indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+LogicalResult ConvOp::verify() {
+ // Batch dimension cannot be present in filter tensor.
+ if (getFilterDims().contains(ConvDimEnum::BATCH))
+ return emitOpError("Batch dimension cannot be present in filter tensor.");
+
+ // Output channel cannot be present in input tensor.
+ if (getInputDims().contains(ConvDimEnum::OUTPUT_CHANNEL))
+ return emitOpError("Output channel cannot be present in input tensor.");
+
+  // Higher spatial dimensions cannot occur without the respective lower ones,
+  // so that the `strides` and `dilations` attributes remain well-defined.
+ bool isSpat2 = getInputDims().contains(ConvDimEnum::SPATIAL_2);
+ bool isSpat1 = getInputDims().contains(ConvDimEnum::SPATIAL_1);
+ bool isSpat0 = getInputDims().contains(ConvDimEnum::SPATIAL_0);
+
+ if ((isSpat2 && (!isSpat1 || !isSpat0)) || (isSpat1 && !isSpat0))
+ return emitOpError("Inconsistent spatial dimensions in `input_dims`.");
+
+ if (!isSpat0)
+ return emitOpError("Requires at least one spatial dimension.");
+
+ // Spatial dimensions have to match between all tensors.
+ if (isSpat2 != getFilterDims().contains(ConvDimEnum::SPATIAL_2) ||
+ isSpat2 != getOutputDims().contains(ConvDimEnum::SPATIAL_2) ||
+ isSpat1 != getFilterDims().contains(ConvDimEnum::SPATIAL_1) ||
+ isSpat1 != getOutputDims().contains(ConvDimEnum::SPATIAL_1) ||
+ isSpat0 != getFilterDims().contains(ConvDimEnum::SPATIAL_0) ||
+ isSpat0 != getOutputDims().contains(ConvDimEnum::SPATIAL_0))
+ return emitOpError("Inconsistent spatial dimensions between tensors.");
+
+ return success();
+}
+
+LogicalResult ConvOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+void ConvOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ConvOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-new-conv.mlir b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
new file mode 100644
index 00000000000000..676c69e2d1a305
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
@@ -0,0 +1,656 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-new-conv -linalg-generalize-named-ops | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_ncw_fcw(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_nwc_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_fgchw_dilated_strided(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw_dilated_strided(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<3> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_nwc_wcf_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_nwc_wcf_memref(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+
+ linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1)>
+// CHECK: #map2 = affine_map<(d0, d1) -> (d0)>
+// CHECK: module {
+// CHECK: func.func @conv1d_8_tensor(%arg0: tensor<11xf32>, %arg1: tensor<4xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<11xf32>, tensor<4xf32>) outs(%arg2 : tensor<8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<8xf32>
+// CHECK: return %0 : tensor<8xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
+ %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
+ outs(%output : tensor<8xf32>) -> tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%arg2 : tensor<8x16x14x14xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<8x16x14x14xf32>
+// CHECK: return %0 : tensor<8x16x14x14xf32>
+// CHECK: }
+// CHECK: }
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+ %0 = linalg.conv_2d_nchw_fchw
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+ outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+ return %0 : tensor<8x16x14x14xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_fgchw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_gfchw(%arg0: tensor<1x5x3x32x32xf32>, %arg1: tensor<5x2x3x3x3xf32>, %arg2: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>) outs(%arg2 : tensor<1x5x2x30x30xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<1x5x2x30x30xf32>
+// CHECK: return %0 : tensor<1x5x2x30x30xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+
+ %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+ outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+ return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_nhwc_fhwc(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<1x14x14x16xf32>
+// CHECK: return %0 : tensor<1x14x14x16xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_nhwgc_gfhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+
+ linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%output: memref<?x?x?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: module {
+// CHECK: func.func @conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) {
+ linalg.conv_2d ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_3d_ncdhw_fcdhw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+
+ %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+
+ %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ linalg.conv_3d ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_ncw_cw(%arg0: tensor<1x8x12xf32>, %arg1: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x8x10xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x8x10xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x8x10xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x8x12xf32>, tensor<8x3xf32>) outs(%1 : tensor<1x8x10xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x8x10xf32>
+// CHECK: return %2 : tensor<1x8x10xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_ncw_cw(%input: tensor<1x8x12xf32>, %filter: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<1x8x10xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+
+ %0 = linalg.depthwise_conv_1d_ncw_cw {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : tensor<1x8x12xf32>, tensor<8x3xf32>)
+ outs(%fill : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+ return %0 : tensor<1x8x10xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_nwc_wc(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x8xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x8xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8xf32>) outs(%1 : tensor<1x10x8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x8xf32>
+// CHECK: return %2 : tensor<1x10x8xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<1x10x8xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+
+ %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8xf32>)
+ outs(%fill : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+ return %0 : tensor<1x10x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_nwc_wcm(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x8x8xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8x8xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x8x8xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8x8xf32>) outs(%1 : tensor<1x10x8x8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x8x8xf32>
+// CHECK: return %2 : tensor<1x10x8x8xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<1x10x8x8xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+
+ %0 = linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8x8xf32>)
+ outs(%fill : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+ return %0 : tensor<1x10x8x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_2d_nchw_chw_tensor(%arg0: tensor<1x96x113x113xf32>, %arg1: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+// CHECK: %0 = tensor.empty() : tensor<1x96x56x56xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>) outs(%0 : tensor<1x96x56x56xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %2 = arith.mulf %in, %in_0 : f32
+// CHECK: %3 = arith.addf %out, %2 : f32
+// CHECK: linalg.yield %3 : f32
+// CHECK: } -> tensor<1x96x56x56xf32>
+// CHECK: return %1 : tensor<1x96x56x56xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+ %init = tensor.empty() : tensor<1x96x56x56xf32>
+
+ %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%input, %filter: tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
+ outs(%init: tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+ return %0: tensor<1x96x56x56xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @convolution_depthwise(%arg0: tensor<1x10x196x48xf32>, %arg1: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x191x48xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x191x48xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x191x48xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>) outs(%1 : tensor<1x10x191x48xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x191x48xf32>
+// CHECK: return %2 : tensor<1x10x191x48xf32>
+// CHECK: }
+// CHECK: }
+func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x10x191x48xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+ %result = linalg.depthwise_conv_2d_nhwc_hwc {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+ outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+ return %result : tensor<1x10x191x48xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_2d_nhwc_hwcm(%arg0: memref<2x4x5x2xf32>, %arg1: memref<2x2x2x3xf32>, %arg2: memref<2x3x4x2x3xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>) outs(%arg2 : memref<2x3x4x2x3xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+ linalg.depthwise_conv_2d_nhwc_hwcm
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+ outs(%output : memref<2x3x4x2x3xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5, d3 + d6, d4 * 3 + d7)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d5, d6, d7)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ncdhw_cdhw(%arg0: tensor<2x6x6x13x12xf32>, %arg1: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x6x3x13x4xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x6x3x13x4xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x6x3x13x4xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>) outs(%1 : tensor<2x6x3x13x4xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x6x3x13x4xf32>
+// CHECK: return %2 : tensor<2x6x3x13x4xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<2x6x6x13x12xf32>, %filter: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<2x6x3x13x4xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+
+ %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+ ins(%input, %filter : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>)
+ outs(%fill : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+ return %0 : tensor<2x6x3x13x4xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 * 2 + d5, d2 + d6, d3 * 3 + d7, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ndhwc_dhwc(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x3x13x4x6xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x3x13x4x6xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>) outs(%1 : tensor<2x3x13x4x6xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x3x13x4x6xf32>
+// CHECK: return %2 : tensor<2x3x13x4x6xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<2x3x13x4x6xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+ ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>)
+ outs(%fill : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+ return %0 : tensor<2x3x13x4x6xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 * 2 + d6, d2 + d7, d3 * 3 + d8, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8, d4, d5)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ndhwc_dhwcm(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x3x13x4x6x6xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>) outs(%1 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x3x13x4x6x6xf32>
+// CHECK: return %2 : tensor<2x3x13x4x6x6xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+ %zero = arith.constant 0.000000e+00 : f32
+ %init = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+ %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+ ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>)
+ outs(%fill : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+ return %0 : tensor<2x3x13x4x6x6xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..5a8f1dbf84db05 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -694,3 +694,60 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @newconv_1d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv {
+ filter_dims = #linalg<conv_dims {F, C, "0"}>,
+ input_dims = #linalg<conv_dims {N, C, "0"}>,
+ output_dims = #linalg<conv_dims {N, F, "0"}>
+ }
+ ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func @newconv_1d(
+// CHECK: %[[arg0:[a-zA-Z0-9]*]]: tensor<?x?x?xf32>, %[[arg1:[a-zA-Z0-9]*]]: tensor<?x?x?xf32>, %[[arg2:[a-zA-Z0-9]*]]: tensor<?x?x?xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {F, C, "0"}>, input_dims = #linalg<conv_dims {N, C, "0"}>, output_dims = #linalg<conv_dims {N, F, "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[arg2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
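+// Note: in the `conv_dims` attribute, N, C, F and G denote the batch, input
+// channel, output channel and group dimensions, while quoted numerals like
+// "0" denote spatial dimensions; the op above is thus a plain channel-first
+// 1-D convolution.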
+
+// -----
+
+func.func @newconv_depthwise_2d(%input: tensor<8x4x16x16xf32>, %filter: tensor<4x3x3xf32>) -> tensor<8x4x14x14xf32> {
+ %init = tensor.empty() : tensor<8x4x14x14xf32>
+
+ %0 = linalg.conv {
+ input_dims = #linalg<conv_dims {N, G, "1", "0"}>,
+ filter_dims = #linalg<conv_dims {G, "1", "0"}>,
+ output_dims = #linalg<conv_dims {N, G, "1", "0"}>
+ }
+ ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+ outs(%init : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+ return %0: tensor<8x4x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_depthwise_2d(
+// CHECK: %[[arg0:[a-zA-Z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-Z0-9]*]]: tensor<4x3x3xf32>
+// CHECK: %[[arg2:[a-zA-Z0-9]*]] = tensor.empty() : tensor<8x4x14x14xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {G, "1", "0"}>, input_dims = #linalg<conv_dims {N, G, "1", "0"}>, output_dims = #linalg<conv_dims {N, G, "1", "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>) outs(%[[arg2]] : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+// -----
+
+func.func @newconv_2d(%input: tensor<8x4x16x16xf32>, %filter: tensor<16x4x3x3xf32>) -> tensor<8x16x14x14xf32> {
+ %init = tensor.empty() : tensor<8x16x14x14xf32>
+ %0 = linalg.conv {
+ input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+ filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+ output_dims = #linalg<conv_dims {N, F, "1", "0"}>
+ }
+ ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+ outs(%init : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+ return %0 : tensor<8x16x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_2d(
+// CHECK: %[[arg0:[a-zA-Z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-Z0-9]*]]: tensor<16x4x3x3xf32>
+// CHECK: %[[arg2:[a-zA-Z0-9]*]] = tensor.empty() : tensor<8x16x14x14xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {F, C, "1", "0"}>, input_dims = #linalg<conv_dims {N, C, "1", "0"}>, output_dims = #linalg<conv_dims {N, F, "1", "0"}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%[[arg2]] : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 283e426b4e5947..37d04821935700 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgRankReduceContractionOps.cpp
TestLinalgTransforms.cpp
+ TestNewConv.cpp
TestPadFusion.cpp
EXCLUDE_FROM_LIBMLIR
@@ -32,4 +33,4 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRVectorDialect
MLIRVectorToSCF
MLIRVectorTransforms
- )
+)
diff --git a/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
new file mode 100644
index 00000000000000..53564738171dbd
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
@@ -0,0 +1,187 @@
+//===- TestNewConv.cpp - Test `linalg.conv` -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass which converts "old" convolution ops,
+// e.g. `linalg.conv_2d_nhwc_fhwc` or `linalg.depthwise_conv_2d_nhwc_hwc`, to
+// the new `linalg.conv` op.
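+//
+// A typical way to exercise this pass and compare the result against the old
+// ops (`input.mlir` being a placeholder) is:
+//
+//   mlir-opt --test-linalg-new-conv --linalg-generalize-named-ops input.mlir
+//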
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+class OldToNewConv : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+public:
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ if (llvm::isa<linalg::ConvOp>(op))
+ return failure();
+    StringRef nameStr = op->getName().stripDialect();
+
+    bool isDepthwise = nameStr.consume_front("depthwise_conv_");
+    if (!isDepthwise && !nameStr.consume_front("conv_"))
+      return failure();
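+    // At this point, e.g. "depthwise_conv_2d_nhwc_hwcm" has been reduced to
+    // "2d_nhwc_hwcm", with `isDepthwise` recording whether the depthwise
+    // prefix was present.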
+
+ int64_t spatialDims;
+ {
+ auto dimensionality = nameStr.substr(0, 2);
+ if (dimensionality == "1d")
+ spatialDims = 1;
+ else if (dimensionality == "2d")
+ spatialDims = 2;
+ else if (dimensionality == "3d")
+ spatialDims = 3;
+ else
+ return failure();
+ }
+
+ SmallVector<ConvDimEnum, 4> inputDims, filterDims, outputDims;
+    if (nameStr.size() == 2) {
+      // These are the ops `conv_1d`, `conv_2d` and `conv_3d`, which use only
+      // spatial dimensions.
+ if (spatialDims == 1)
+ filterDims = inputDims = {ConvDimEnum::SPATIAL_0};
+ else if (spatialDims == 2)
+ filterDims =
+ inputDims = {ConvDimEnum::SPATIAL_0, ConvDimEnum::SPATIAL_1};
+ else if (spatialDims == 3)
+ filterDims =
+ inputDims = {ConvDimEnum::SPATIAL_0, ConvDimEnum::SPATIAL_1,
+ ConvDimEnum::SPATIAL_2};
+ else
+ return failure();
+
+ } else {
+ // This handles all the ops with specialized tensor dimension orders like
+ // `conv_2d_nhwc_fhwc`, `depthwise_conv_2d_nhwc_hwc`, etc.
+      auto specialization = nameStr.substr(3); // Skip the "<n>d_" prefix.
+
+ // Separator between input and filter layout.
+ auto sep = specialization.find('_');
+ if (sep == StringRef::npos)
+ return failure();
+ auto inputDimStr = specialization.substr(0, sep);
+ auto filterDimStr = specialization.substr(sep + 1);
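+      // E.g. the specialization "nhwc_fhwc" splits into inputDimStr = "nhwc"
+      // and filterDimStr = "fhwc".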
+
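+      // Map each layout letter to a dimension kind. Note that 'w' maps to
+      // SPATIAL_0: the spatial dimensions are numbered from innermost to
+      // outermost.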
+ auto parseDim = [&](char c) -> ConvDimEnum {
+ switch (c) {
+ case 'n':
+ return ConvDimEnum::BATCH;
+ case 'h':
+ return ConvDimEnum::SPATIAL_1;
+ case 'w':
+ return ConvDimEnum::SPATIAL_0;
+ case 'd':
+ return ConvDimEnum::SPATIAL_2;
+ case 'f':
+ return ConvDimEnum::OUTPUT_CHANNEL;
+ case 'g':
+ return ConvDimEnum::GROUP;
+ case 'c':
+ // The old convolution ops use the letter 'c' to denote a
+ // non-reduction dimension in all tensors in the depthwise case. The
+ // new convolution captures this behavior in the group dimension.
+ return isDepthwise ? ConvDimEnum::GROUP : ConvDimEnum::INPUT_CHANNEL;
+ case 'm':
+        // Similarly, the old convolution ops use the letter 'm' to denote a
+        // parallel dimension in filter and output in the depthwise case. This
+        // behavior is captured by the ordinary output channel dimension.
+        assert(isDepthwise && "Unexpected letter 'm' in non-depthwise conv");
+        return ConvDimEnum::OUTPUT_CHANNEL;
+      default:
+        llvm_unreachable("unknown dimension character");
+ }
+ };
+
+ inputDims = llvm::map_to_vector(inputDimStr, parseDim);
+ filterDims = llvm::map_to_vector(filterDimStr, parseDim);
+ }
+
+ // This is the behavior of the old convolution ops:
+ // The output dimension order is the same as the input dimension order, but
+ // output channel stands in for input channel...
+    for (ConvDimEnum d : inputDims) {
+      if (d == ConvDimEnum::INPUT_CHANNEL)
+        outputDims.push_back(ConvDimEnum::OUTPUT_CHANNEL);
+      else
+        outputDims.push_back(d);
+    }
+ // ... and if the "depthwise channel multiplier" dimension 'm' appears, the
+ // output tensor has an additional dimension appended.
+ if (isDepthwise &&
+ llvm::is_contained(filterDims, ConvDimEnum::OUTPUT_CHANNEL))
+ outputDims.push_back(ConvDimEnum::OUTPUT_CHANNEL);
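+    // E.g. for `depthwise_conv_2d_nhwc_hwcm`, the input dims "nhwc" (with 'c'
+    // parsed as GROUP) yield output dims "nhwc", and the trailing 'm' in the
+    // filter appends an extra OUTPUT_CHANNEL dimension, giving "nhwcm".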
+
+    SmallVector<int64_t> strides(spatialDims, 1), dilations(spatialDims, 1);
+    // The old convolution ops order strides and dilations as "D, H, W",
+    // whereas here they are ordered as spatial 0, spatial 1, spatial 2
+    // (innermost first), so the order has to be reversed.
+    if (auto stridesAttr = op->getAttrOfType<DenseElementsAttr>("strides")) {
+      auto vals = llvm::to_vector(stridesAttr.getValues<int64_t>());
+      strides.assign(vals.rbegin(), vals.rend());
+    }
+    if (auto dilationsAttr =
+            op->getAttrOfType<DenseElementsAttr>("dilations")) {
+      auto vals = llvm::to_vector(dilationsAttr.getValues<int64_t>());
+      dilations.assign(vals.rbegin(), vals.rend());
+    }
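+    // E.g. `strides = dense<[2, 1, 3]>` on a 3-D op (given in "D, H, W"
+    // order) becomes {3, 1, 2} in spatial-0-first order.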
+
+ rewriter.replaceOpWithNewOp<linalg::ConvOp>(
+ op, op->getResultTypes(), op->getOperand(0), op->getOperand(1),
+ op->getOperand(2), inputDims, filterDims, outputDims, strides,
+ dilations);
+
+ return success();
+ }
+};
+
+struct TestNewConvPass : public PassWrapper<TestNewConvPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNewConvPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+
+ StringRef getArgument() const final { return "test-linalg-new-conv"; }
+  StringRef getDescription() const final { return "Test new linalg.conv op"; }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ ConversionTarget target(getContext());
+
+ target.addLegalOp<linalg::ConvOp>();
+    // Any old conv op that was not converted should make the conversion fail.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !op->getName().getStringRef().contains("conv");
+ });
+
+ patterns.add<OldToNewConv>(context);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestNewConv() { PassRegistration<TestNewConvPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..25a8430500b6cc 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
+void registerTestNewConv();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -248,6 +249,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
mlir::test::registerTestLinalgTransforms();
+ mlir::test::registerTestNewConv();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLoopFusion();