[Mlir-commits] [mlir] [mlir][linalg] Implement `LinalgGroupedConvolutionOpInterface` to unify grouped convs (PR #94796)
llvmlistbot at llvm.org
Mon Jul 1 13:36:41 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/94796
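For orientation, here is a minimal sketch of the new op in its default
["ngcs", "gfcs", "ngfs"] layouts, adapted from the named-ops test added below
(the function name is illustrative):

  // 1-D grouped convolution on memrefs: batch 64, 8 groups, 16 input channels
  // per group, 32 output channels per group, unit stride and dilation.
  func.func @grouped_conv_1d(%in: memref<64x8x16x10xf32>,
                             %ker: memref<8x32x16x3xf32>,
                             %out: memref<64x8x32x8xf32>) {
    linalg.grouped_conv_nd ins(%in, %ker : memref<64x8x16x10xf32>, memref<8x32x16x3xf32>)
                           outs(%out : memref<64x8x32x8xf32>)
    return
  }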
From b361e4d021dbde4d3c2c1fca6dd36030a7aa4376 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 14:48:18 -0500
Subject: [PATCH 1/3] Implement `LinalgGroupedConvolutionOpInterface` to unify
grouped convs
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 37 +++++
.../Dialect/Linalg/IR/LinalgInterfaces.td | 95 ++++++++++++
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 141 ++++++++++++++++++
.../mlir/Dialect/Utils/StructuredOpsUtils.td | 12 ++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 129 ++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 104 +++++++++++++
mlir/test/Dialect/Linalg/bufferize.mlir | 21 +++
mlir/test/Dialect/Linalg/loops.mlir | 8 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 11 ++
mlir/test/Dialect/Linalg/tile-conv.mlir | 62 +++++++-
10 files changed, 614 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..ed5b4ff2de4dc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
class GenericOp;
namespace detail {
@@ -133,6 +135,38 @@ std::optional<Value> isaFillOpInterface(GenericOp genericOp);
namespace detail {
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+void getEffects(
+ Operation *op,
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+ bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(
+ MLIRContext *ctx, int64_t numSpatial,
+ const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+ const SmallVectorImpl<int64_t> &strides,
+ const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
/// Returns true if the block contains a contraction of the following form:
///
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -189,6 +223,9 @@ LogicalResult verifyContractionInterface(Operation *op);
/// Verify that `op` conforms to the ConvolutionOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op);
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..5ae481a222e3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,101 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
return $_op.getOperation()->getOperand(1);
}]
>,
+ InterfaceMethod<
+ /*desc=*/"Return the spatial rank.",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getSpatialRank",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Most convolution inputs have batch, channel and spatial dims
+ return cast<ShapedType>(image().getType()).getRank() - 2;
+ }]
+ >
+ ];
+}
+
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+ LinalgConvolutionOpInterface]> {
+ let description = [{
+ A grouped convolution is defined in general terms:
+ 1. It is a convolution as defined by `ConvolutionOpInterface`.
+ 2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial
+ 3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
+ 4. Reductions are along the input channel and spatial dimensions while group, output channel
+ and output spatial dimensions are parallel.
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the groups position for the input.
+ }],
+ "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the groups position for the input.
+ }],
+ "int64_t", "getInputGroupsPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the channel position for the input.
+ }],
+ "int64_t", "getInputChannelPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the channel position for the output.
+ }],
+ "int64_t", "getOutputChannelPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Get number of groups.
+ }],
+ "int64_t", "getNumGroups", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
+ }]>,
+ InterfaceMethod<[{
+ Get number of input channels.
+ }],
+ "int64_t", "getNumInputChannels", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
+ }]>,
+ InterfaceMethod<[{
+ Get number of output channels.
+ }],
+ "int64_t", "getNumOutputChannels", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
+ }]>,
+ InterfaceMethod<[{
+ Returns indexing maps for any spatial dimension.
+ }],
+ "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::grouped_convolution_impl::getIteratorTypes($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns strides.
+ }],
+ "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::convolution_impl::getStrides($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns dilations.
+ }],
+ "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::convolution_impl::getDilations($_op);
+ }]>
];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e3..7db7c54a4ea09 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,147 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp ops.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+ [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs N-D grouped convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+ Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
+ will represent all spatial dimensions. Operand layouts are determined by the `layouts`
+ `StrArrayAttr` attribute. Each element of the array is a string representing the layout of the
+ corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
+ n: (batch dim)
+ g: (group dim)
+ f: (feature or output channel dim)
+ s: (all spatial dims)
+ c: (input channel dim).
+
+ The domain will always be in the order `(N, G, F, S, C, KS)`.
+
+ }];
+
+ let arguments = (ins
+ Variadic<TensorOrMemref>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
+ OptionalAttr<I64ElementsAttr>:$strides,
+ OptionalAttr<I64ElementsAttr>:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "Value":$input, "Value":$filter, "Value":$init,
+ CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
+ // Copy into local vectors so the ArrayRef parameters are never rebound to
+ // temporaries (which would leave them dangling).
+ ::llvm::SmallVector<int64_t, 2> stridesVec(strides.begin(), strides.end());
+ if (stridesVec.empty())
+ stridesVec.assign(numSpatialDims, 1);
+ ::llvm::SmallVector<int64_t, 2> dilationsVec(dilations.begin(), dilations.end());
+ if (dilationsVec.empty())
+ dilationsVec.assign(numSpatialDims, 1);
+ $_state.addAttribute(getStridesAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), stridesVec));
+ $_state.addAttribute(getDilationsAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilationsVec));
+ buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
+ attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+ "Value":$init,
+ CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
+ // Copy into local vectors so the ArrayRef parameters are never rebound to
+ // temporaries (which would leave them dangling).
+ ::llvm::SmallVector<int64_t, 2> stridesVec(strides.begin(), strides.end());
+ if (stridesVec.empty())
+ stridesVec.assign(numSpatialDims, 1);
+ ::llvm::SmallVector<int64_t, 2> dilationsVec(dilations.begin(), dilations.end());
+ if (dilationsVec.empty())
+ dilationsVec.assign(numSpatialDims, 1);
+ $_state.addAttribute(getStridesAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), stridesVec));
+ $_state.addAttribute(getDilationsAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilationsVec));
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+ "Value":$init, "Attribute":$strides, "Attribute":$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getStridesAttrName($_state.name), strides);
+ $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
+ attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>
+ ];
+
+ // TODO: Figure out how to move this to the interface
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::convolution_impl::parse(parser, result);
+ }
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return detail::convolution_impl::regionBuilder;
+ }
+ // Implement functions necessary for DestinationStyleOpInterface.
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+ // Implement functions necessary for LinalgOp.
+ ArrayAttr getIndexingMaps();
+
+ // Implement functions necessary for GroupedConvolutionOpInterface
+ int64_t getSpatialRank() {
+ return detail::grouped_convolution_impl::getSpatialRank(*this);
+ }
+
+ SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+ SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
+ for (auto attr : (*this).getLayoutsAttr().getValue()) {
+ std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+ SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+ for (size_t i = 0; i < layoutStr.size(); i++) {
+ auto maybeDimEnum = ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1).c_str());
+ assert(maybeDimEnum);
+ layout[i] = maybeDimEnum.value();
+ }
+ layouts.push_back(layout);
+ }
+ return layouts;
+ }
+
+ int64_t getOutputChannelPosition() {
+ return 2;
+ }
+
+ int64_t getInputChannelPosition() {
+ return 2;
+ }
+
+ int64_t getInputGroupsPosition() {
+ return 1;
+ }
+ }];
+}
//===----------------------------------------------------------------------===//
// Transpose op.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e13..c7c5d617f6492 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
let cppNamespace = "::mlir::utils";
}
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+ [
+ I32EnumAttrCase<"n", 0>, // batch
+ I32EnumAttrCase<"g", 1>, // group
+ I32EnumAttrCase<"f", 2>, // feature (output channel)
+ I32EnumAttrCase<"s", 3>, // spatial
+ I32EnumAttrCase<"c", 4> // channel (input channel)
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::utils";
+}
+
#endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..c2db6670e4167 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -766,6 +766,135 @@ enum class MatchConvolutionResult {
};
} // namespace mlir::linalg::detail
+SmallVector<int64_t, 2>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+ auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+ if (!maybeStridesAttr) {
+ OpBuilder builder(op.getContext());
+ return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+ }
+ return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+SmallVector<int64_t, 2> mlir::linalg::detail::convolution_impl::getDilations(
+ ConvolutionOpInterface op) {
+ auto maybeDilationsAttr =
+ op->getAttrOfType<DenseIntElementsAttr>("dilations");
+ if (!maybeDilationsAttr) {
+ OpBuilder builder(op.getContext());
+ return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+ }
+ return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+ GroupedConvolutionOpInterface op) {
+ return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+ GroupedConvolutionOpInterface op) {
+ int64_t numSpatialDims = op.getSpatialRank();
+ SmallVector<Attribute> iteratorTypes(
+ 3 + numSpatialDims,
+ IteratorTypeAttr::get(op.getContext(), utils::IteratorType::parallel));
+ SmallVector<Attribute> reductions(
+ numSpatialDims + 1,
+ IteratorTypeAttr::get(op.getContext(), utils::IteratorType::reduction));
+ iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+ reductions.end());
+
+ return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+ MLIRContext *ctx, int64_t numSpatial,
+ const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+ const SmallVectorImpl<int64_t> &strides,
+ const SmallVectorImpl<int64_t> &dilations) {
+ assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
+
+ // Domain: (n, g, f, os, c, ks)
+ AffineExpr n = getAffineDimExpr(0, ctx);
+ AffineExpr g = getAffineDimExpr(1, ctx);
+ AffineExpr f = getAffineDimExpr(2, ctx);
+ SmallVector<AffineExpr> s(
+ llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+ SmallVector<AffineExpr> ks(llvm::map_range(
+ llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+ SmallVector<AffineExpr> inSpatials;
+ inSpatials.reserve(numSpatial);
+ for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
+ inSpatials.push_back(sp * st + ksp * di);
+ }
+
+ auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+ const SmallVector<AffineExpr> &spatials) {
+ SmallVector<AffineExpr> exprs(layout.size());
+ int64_t spatialDim = 0;
+ for (const auto &[i, dim] : llvm::enumerate(layout)) {
+ switch (dim) {
+ case utils::GroupedConvDim::n:
+ exprs[i] = n;
+ break;
+ case utils::GroupedConvDim::g:
+ exprs[i] = g;
+ break;
+ case utils::GroupedConvDim::f:
+ exprs[i] = f;
+ break;
+ case utils::GroupedConvDim::s:
+ exprs[i] = spatials[0];
+ spatialDim = i;
+ break;
+ case utils::GroupedConvDim::c:
+ exprs[i] = c;
+ break;
+ default:
+ assert(false);
+ }
+ }
+ if (spatials.size() > 1)
+ exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+ spatials.end());
+ return exprs;
+ };
+ SmallVector<AffineExpr> inExprs = getExprs(layouts[0], inSpatials);
+ SmallVector<AffineExpr> kExprs = getExprs(layouts[1], ks);
+ SmallVector<AffineExpr> outExprs = getExprs(layouts[2], s);
+ SmallVector<AffineMap> maps(
+ {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+
+ return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+ if (failed(verifyConvolutionInterface(op)))
+ return failure();
+ if (GroupedConvolutionOpInterface conv =
+ dyn_cast<GroupedConvolutionOpInterface>(op)) {
+ const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+ const auto imageRank = imageType.getRank();
+ const auto kernelRank =
+ conv.filter().getType().cast<ShapedType>().getRank();
+ const auto initType =
+ cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+ const auto initRank = initType.getRank();
+ if (imageRank != kernelRank || imageRank != initRank)
+ return op->emitError(
+ "Rank relationship must be `in_rank == out_rank == kernel_rank`");
+ return success();
+ }
+ return failure();
+}
+
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..21cc22f034aa6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1735,6 +1735,110 @@ LogicalResult ReduceOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+ OpAsmParser &parser, OperationState &result, bool isQuantized) {
+ if (isQuantized)
+ return parseNamedStructuredOp(
+ parser, result, 5,
+ mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+ return parseNamedStructuredOp(
+ parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+ OpAsmPrinter &p) {
+ printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+ op.getDpsInits());
+}
+
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "ConvolutionInterface regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 5 &&
+ "ConvolutionInterface regionBuilder expects 5 args");
+ RegionBuilderHelper helper(b, block);
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(2));
+ Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+ Value value4 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(1));
+ Value value5 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(3));
+ Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+ Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+ Value value8 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+ helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+ Operation *op,
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (!isa<ConvolutionOpInterface>(op))
+ return;
+ if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+ if (linalgOp.hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, linalgOp);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+ ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+ getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
+ getDilationsVector());
+
+ (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..56fc39f5fc073 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,24 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
// CHECK: return %[[OUT_TENSOR]]
}
+
+// -----
+
+// CHECK-LABEL: func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK: linalg.grouped_conv_nd
+// CHECK-SAME: dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME: strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+ %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+ return %0 : tensor<64x2x20x8x8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b818170a8e797..df89029e27d86 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b..a231569672209 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+ // CHECK: grouped_conv_nd
+ linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
%zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index f674996e42f33..6d19665931662 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.conv_2d
// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
// CHECK-SAME: outs(%[[SVOUT]]
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+ linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop:5 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: func @grouped_conv_2D
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+// CHECK-DAG: %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+// CHECK-DAG: %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+// CHECK: scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+// CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+// CHECK-DAG: %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+// CHECK-DAG: %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+// CHECK-DAG: %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+// CHECK-DAG: %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+// CHECK-DAG: %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+// CHECK: linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
+// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
+// CHECK-SAME: outs(%[[SVOUT]]
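For reference, a sketch of the fully generalized form of the 1-D default-layout
example above. The indexing maps and iterator types are taken from the
generalization tests added in the third patch; the function and map names are
illustrative. The domain order is (n, g, f, s, c, ks) = (d0, ..., d5), and with
non-unit strides/dilations the image's spatial expression becomes
d3 * stride + d5 * dilation.

  #image_map  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3 + d5)>  // "ngcs"
  #kernel_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>       // "gfcs"
  #init_map   = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>       // "ngfs"

  func.func @grouped_conv_1d_generic(%in: memref<64x8x16x10xf32>,
                                     %ker: memref<8x32x16x3xf32>,
                                     %out: memref<64x8x32x8xf32>) {
    linalg.generic {indexing_maps = [#image_map, #kernel_map, #init_map],
                    iterator_types = ["parallel", "parallel", "parallel", "parallel",
                                      "reduction", "reduction"]}
        ins(%in, %ker : memref<64x8x16x10xf32>, memref<8x32x16x3xf32>)
        outs(%out : memref<64x8x32x8xf32>) {
    ^bb0(%a: f32, %b: f32, %acc: f32):
      %prod = arith.mulf %a, %b : f32
      %sum = arith.addf %acc, %prod : f32
      linalg.yield %sum : f32
    }
    return
  }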
From f0678a4fdf0da49c5742c65cf63614f3e41b7e35 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jul 2024 13:25:05 -0500
Subject: [PATCH 2/3] small refactor
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 81 ++++++++-----------
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 34 ++++++--
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 6 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 ----
4 files changed, 64 insertions(+), 71 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 5ae481a222e3c..e8b61c2cebf3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -203,9 +203,10 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
let methods = [
InterfaceMethod<[{
- Returns the groups position for the input.
+ Returns the layouts of each operand (image, kernel, init). Each layout is represented
+ by a vector of `GroupedConvDim`s.
}],
- "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+ "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getOperandConvDims", (ins)
>,
InterfaceMethod<[{
Returns the groups position for the input.
@@ -222,55 +223,39 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
}],
"int64_t", "getOutputChannelPosition", (ins)
>,
- InterfaceMethod<[{
- Get number of groups.
- }],
- "int64_t", "getNumGroups", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
- }]>,
- InterfaceMethod<[{
- Get number of input channels.
- }],
- "int64_t", "getNumInputChannels", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
- }]>,
- InterfaceMethod<[{
- Get number of output channels.
- }],
- "int64_t", "getNumOutputChannels", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
- }]>,
- InterfaceMethod<[{
- Returns indexing maps for any spatial dimension.
- }],
- "::mlir::ArrayAttr", "getIteratorTypes", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
+ ];
+
+ let extraSharedClassDeclaration = [{
+ // Get number of groups.
+ int64_t getNumGroups() {
+ return cast<ShapedType>(
+ cast<::mlir::linalg::ConvolutionOpInterface>(
+ $_op.getOperation()).image().getType())
+ .getShape()[$_op.getInputGroupsPosition()];
+ }
+ // Get number of input channels.
+ int64_t getNumInputChannels() {
+ return cast<ShapedType>(
+ cast<::mlir::linalg::ConvolutionOpInterface>(
+ $_op.getOperation()).image().getType()).getShape()[$_op.getInputChannelPosition()];
+ }
+ // Get number of output channels.
+ int64_t getNumOutputChannels() {
+ return cast<ShapedType>($_op->getOperand(2).getType()).getShape()[$_op.getOutputChannelPosition()];
+ }
+ // Returns iterator types.
+ ::mlir::ArrayAttr getIteratorTypes() {
return detail::grouped_convolution_impl::getIteratorTypes($_op);
- }]>,
- InterfaceMethod<[{
- Returns strides.
- }],
- "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
+ }
+ // Returns strides.
+ ::llvm::SmallVector<int64_t, 2> getStridesVector() {
return detail::convolution_impl::getStrides($_op);
- }]>,
- InterfaceMethod<[{
- Returns dilations.
- }],
- "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
+ }
+ // Returns dilations.
+ ::llvm::SmallVector<int64_t, 2> getDilationsVector() {
return detail::convolution_impl::getDilations($_op);
- }]>
- ];
+ }
+ }];
}
def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 7db7c54a4ea09..ba67fd7e46a65 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -392,7 +392,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
[AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
let summary = [{
- Performs N-D grouped convolution with switchable channel position; either first or last.
+ Performs N-D grouped convolution with parametrizable operand layouts.
}];
let description = [{
Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
@@ -490,14 +490,27 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
// Implement functions necessary for LinalgOp.
- ArrayAttr getIndexingMaps();
+ ::mlir::ArrayAttr getIndexingMaps() {
+ ::mlir::ArrayAttr cached = (*this)->getAttrOfType<::mlir::ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+ getContext(), getSpatialRank(), getOperandConvDims(), getStridesVector(),
+ getDilationsVector());
+
+ (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+ }
+
// Implement functions necessary for GroupedConvolutionOpInterface
int64_t getSpatialRank() {
return detail::grouped_convolution_impl::getSpatialRank(*this);
}
- SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+ SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getOperandConvDims() {
SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
for (auto attr : (*this).getLayoutsAttr().getValue()) {
std::string layoutStr = cast<StringAttr>(attr).getValue().str();
@@ -513,15 +526,24 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
}
int64_t getOutputChannelPosition() {
- return 2;
+ std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[2]).getValue().str();
+ size_t pos = layoutStr.find("f");
+ assert(pos != ::std::string::npos);
+ return pos;
}
int64_t getInputChannelPosition() {
- return 2;
+ std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
+ size_t pos = layoutStr.find("c");
+ assert(pos != ::std::string::npos);
+ return pos;
}
int64_t getInputGroupsPosition() {
- return 1;
+ std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
+ size_t pos = layoutStr.find("g");
+ assert(pos != ::std::string::npos);
+ return pos;
}
}];
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c2db6670e4167..e0020a5497148 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -880,12 +880,12 @@ mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
return failure();
if (GroupedConvolutionOpInterface conv =
dyn_cast<GroupedConvolutionOpInterface>(op)) {
- const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+ const auto imageType = cast<ShapedType>(conv.image().getType());
const auto imageRank = imageType.getRank();
const auto kernelRank =
- conv.filter().getType().cast<ShapedType>().getRank();
+ cast<ShapedType>(conv.filter().getType()).getRank();
const auto initType =
- cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+ cast<ShapedType>(cast<LinalgOp>(op).getDpsInits()[0].getType());
const auto initRank = initType.getRank();
if (imageRank != kernelRank || imageRank != initRank)
return op->emitError(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 21cc22f034aa6..080692dfd6de0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1825,20 +1825,6 @@ void GroupedConvNDOp::getEffects(
return detail::convolution_impl::getEffects(*this, effects);
}
-ArrayAttr GroupedConvNDOp::getIndexingMaps() {
- ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
- LinalgDialect::kMemoizedIndexingMapsAttrName);
- if (cached)
- return cached;
-
- cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
- getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
- getDilationsVector());
-
- (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
- return cached;
-}
-
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
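To make the patch-2 refactor concrete: the group and channel positions are now
read off the `layouts` attribute instead of being hard-coded. For the default
layouts (shapes reused from the 1-D example above), the accessors added in this
patch work out as follows; the values are derived from the code, not from new
tests.

  // layouts = ["ngcs", "gfcs", "ngfs"]
  //   getInputGroupsPosition()   -> 1  ('g' in "ngcs")
  //   getInputChannelPosition()  -> 2  ('c' in "ngcs")
  //   getOutputChannelPosition() -> 2  ('f' in "ngfs")
  // For ins(... : memref<64x8x16x10xf32>, ...) outs(... : memref<64x8x32x8xf32>):
  //   getNumGroups() = 8, getNumInputChannels() = 16, getNumOutputChannels() = 32.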
From 30ff1a15cad37169fe4d55c5dbc45e32c7c735c3 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jul 2024 15:36:19 -0500
Subject: [PATCH 3/3] add a couple tests to convert to generic
---
.../Dialect/Linalg/generalize-named-ops.mlir | 60 +++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b4165..ff91e25db6502 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -864,3 +864,63 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3 + d5)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: func @gen_grouped_1D_ngcs_gfcs_ngfs_memref
+func.func @gen_grouped_1D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: }
+ linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @gen_grouped_2D_ngcs_gfcs_ngfs_memref
+func.func @gen_grouped_2D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x2x16x26x26xf32>, %arg1: memref<2x20x16x3x3xf32>, %arg2: memref<64x2x20x8x8xf32>) {
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: }
+ linalg.grouped_conv_nd {strides = dense<3> : tensor<2xi64>, dilations = dense<2> : tensor<2xi64>, layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x2x16x26x26xf32>, memref<2x20x16x3x3xf32>) outs(%arg2: memref<64x2x20x8x8xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2, d5)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d6, d7, d2, d5)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d2)>
+// CHECK: func @gen_grouped_2D_ngsc_gsfc_ngsf_memref
+func.func @gen_grouped_2D_ngsc_gsfc_ngsf_memref(%arg0: memref<64x2x26x26x16xf32>, %arg1: memref<2x3x3x20x16xf32>, %arg2: memref<64x2x8x8x20xf32>) {
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: }
+ linalg.grouped_conv_nd {strides = dense<3> : tensor<2xi64>, dilations = dense<2> : tensor<2xi64>, layouts = ["ngsc", "gsfc", "ngsf"]} ins(%arg0, %arg1: memref<64x2x26x26x16xf32>, memref<2x3x3x20x16xf32>) outs(%arg2: memref<64x2x8x8x20xf32>)
+ return
+}
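A reading aid for the strided 2-D tests above: with strides = [3, 3] and
dilations = [2, 2], each input spatial index in the image map is formed as
out_spatial * stride + kernel_spatial * dilation, which is where the
d3 * 3 + d6 * 2 and d4 * 3 + d7 * 2 terms come from.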